Commit 66f8e4bb authored by Adam Osewski's avatar Adam Osewski
Browse files

Support A/B/C elementwise ops.

parent b9ab9f4b
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -26,13 +26,20 @@ namespace device { ...@@ -26,13 +26,20 @@ namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename GemmDesc, typename GemmDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation> InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl_splitk(
const index_t group_count) const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op = AElementwiseOperation{},
const BElementwiseOperation b_element_op = BElementwiseOperation{},
const CElementwiseOperation c_element_op = CElementwiseOperation{})
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__)) defined(__gfx94__))
...@@ -64,10 +71,16 @@ __global__ void ...@@ -64,10 +71,16 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_, gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared), static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_); gemm_desc_ptr[group_id].block_2_ctile_map_,
a_element_op,
b_element_op,
c_element_op);
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment