Commit 9685fed2 authored by Chao Liu

clean up

parent 11b83234
...
@@ -16,6 +16,102 @@ namespace ck {
namespace tensor_operation {
namespace device {
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, and C
* matrices for a given batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
* offsets for evenly strided batches, but this can easily be extended to other layouts. The
* returned offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we
* are not subject to the 2 GB limitation.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in the ID of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility to let two workgroups compute
* tiles from two different matrices. Keep in mind that these two matrices can share the same grid
* descriptor (as in BatchedGEMM) or use their own grid descriptors (as in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computation
* of the pointer offset in \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows a customized mapping between a workgroup and the C-tile it
* computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
* fusions) to realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusions).
*
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ADataType,
typename BDataType,
typename CDataType,
...
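For readers unfamiliar with the ComputePtrOffsetOfBatch concept documented above, a minimal functor along the following lines satisfies it for evenly strided batches. This is an illustrative sketch, not code from this commit: plain C++ integer types stand in for ck::index_t / ck::long_index_t, the device qualifiers are omitted, and the member names are assumptions.

```cpp
#include <cstdint>

// Hedged sketch of a functor satisfying the ComputePtrOffsetOfBatch concept
// for evenly strided batches; member names are illustrative placeholders.
struct ComputePtrOffsetOfStridedBatchSketch
{
    // Returning a 64-bit offset is what lifts the 2 GB restriction mentioned
    // in the comment above: g_idx * stride may exceed the 32-bit index range.
    std::int64_t GetAPtrOffset(std::int32_t g_idx) const
    {
        return static_cast<std::int64_t>(g_idx) * batch_stride_a;
    }
    std::int64_t GetBPtrOffset(std::int32_t g_idx) const
    {
        return static_cast<std::int64_t>(g_idx) * batch_stride_b;
    }
    std::int64_t GetCPtrOffset(std::int32_t g_idx) const
    {
        return static_cast<std::int64_t>(g_idx) * batch_stride_c;
    }

    std::int32_t batch_stride_a; // elements between consecutive A matrices
    std::int32_t batch_stride_b; // elements between consecutive B matrices
    std::int32_t batch_stride_c; // elements between consecutive C matrices
};
```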
...
@@ -18,9 +18,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
-/*
- * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
- */
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
...
...
@@ -20,9 +20,9 @@ namespace tensor_operation {
namespace device {
/*
- * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ * \brief Wrapper function of GridwiseGemm::Run to realize Split-K GEMM.
 *
- * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ * \see \link device_gemm_xdl_splitk.hpp kernel_gemm_xdl_splitk
 */
template <typename GridwiseGemm,
typename FloatAB,
...
@@ -41,7 +41,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
- kernel_batched_gemm_xdlops_v2r3(
+ kernel_gemm_xdl_splitk(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
...
@@ -616,7 +616,7 @@ struct DeviceGemmXdlSplitK
if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
- const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+ const auto kernel = kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
...
@@ -635,7 +635,7 @@ struct DeviceGemmXdlSplitK
}
else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
{
- const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+ const auto kernel = kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
...
@@ -654,7 +654,7 @@ struct DeviceGemmXdlSplitK
}
else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
{
- const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+ const auto kernel = kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
...
@@ -673,7 +673,7 @@ struct DeviceGemmXdlSplitK
}
else
{
- const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+ const auto kernel = kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
...
@@ -698,7 +698,7 @@ struct DeviceGemmXdlSplitK
if(has_main_k0_block_loop)
{
- const auto kernel = ck::kernel_batched_gemm_xdlops_v2r3<
+ const auto kernel = ck::kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
...
@@ -732,7 +732,7 @@ struct DeviceGemmXdlSplitK
}
else
{
- const auto kernel = ck::kernel_batched_gemm_xdlops_v2r3<
+ const auto kernel = ck::kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
...
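The renaming above reflects what the kernel implements: a split-K GEMM, where the K dimension is divided into k_batch slices whose partial products are accumulated into C. As a reading aid, here is a host-side reference sketch of that decomposition (assumptions: row-major A, B, C, and C zero-initialized before the call); it is not the device algorithm, only the arithmetic that the kernel distributes across batches of workgroups.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Reference-level sketch of split-K for C = A * B with row-major A (M x K),
// B (K x N), C (M x N). Each K-slice contributes a partial product that is
// accumulated into C.
void gemm_splitk_reference(const std::vector<float>& A,
                           const std::vector<float>& B,
                           std::vector<float>& C, // assumed zero-initialized
                           std::size_t M,
                           std::size_t N,
                           std::size_t K,
                           std::size_t k_batch)
{
    const std::size_t k_per_slice = (K + k_batch - 1) / k_batch; // round up
    for(std::size_t slice = 0; slice < k_batch; ++slice)
    {
        const std::size_t k_begin = slice * k_per_slice;
        const std::size_t k_end   = std::min(K, k_begin + k_per_slice);
        for(std::size_t m = 0; m < M; ++m)
        {
            for(std::size_t n = 0; n < N; ++n)
            {
                float partial = 0.f;
                for(std::size_t k = k_begin; k < k_end; ++k)
                {
                    partial += A[m * K + k] * B[k * N + n];
                }
                C[m * N + n] += partial; // cross-slice accumulation
            }
        }
    }
}
```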
...
@@ -22,8 +22,6 @@ namespace device {
/*
 * \brief Wrapper function of GridwiseGemm::Run to realize a customized BatchedGemm for splitK.
 *
- * The main difference from \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
- * is that there are 2 different tensor descriptors for matrix A and B.
 */
template <typename GridwiseGemm,
typename FloatAB,
...
@@ -209,7 +207,6 @@ struct DeviceGemmXdlSplitKCShuffle
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
...
@@ -273,7 +270,8 @@ struct DeviceGemmXdlSplitKCShuffle
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
-{ // pad both N and K
+{
+// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
...
@@ -290,8 +288,9 @@ struct DeviceGemmXdlSplitKCShuffle
return b_grid_desc_bk0_n_bk1;
}
-else // pad K, but not N
+else
{
+// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
...
@@ -395,9 +394,9 @@ struct DeviceGemmXdlSplitKCShuffle
}
private:
- // TODO: should they be long_index_t?
index_t BatchStrideA_;
index_t BatchStrideB_;
- // index_t BatchStrideC_; // always zero
};
// GridwiseGemm
...
@@ -476,8 +475,11 @@ struct DeviceGemmXdlSplitKCShuffle
GetActualBatchAndKSplitted<AK1>(KRaw, k_batch);
const auto actual_batch_and_ksplitted_B =
GetActualBatchAndKSplitted<BK1>(KRaw, k_batch);
assert(actual_batch_and_ksplitted_A.first == actual_batch_and_ksplitted_B.first);
BatchCount_ = actual_batch_and_ksplitted_A.first;
const auto AKSplitted = actual_batch_and_ksplitted_A.second;
const auto BKSplitted = actual_batch_and_ksplitted_B.second;
...
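The padding branches above right-pad M, N, or K so that each extent becomes a whole number of per-block tiles. A small self-contained sketch of that round-up arithmetic follows; whether MPad/NPad/KPad in the excerpt denote the padded length or the pad amount is not visible in this diff, so the sketch shows both quantities and should be read as an assumption-level illustration.

```cpp
#include <cstdint>

// Hedged sketch of the round-up arithmetic behind the M/N/K padding branches:
// a raw extent is padded on the right until it is a multiple of the per-block
// tile size, so every workgroup sees a full tile.
constexpr std::int32_t padded_length(std::int32_t raw, std::int32_t per_block)
{
    return ((raw + per_block - 1) / per_block) * per_block;
}

constexpr std::int32_t pad_amount(std::int32_t raw, std::int32_t per_block)
{
    return padded_length(raw, per_block) - raw;
}

// Example: raw N = 1000 with NPerBlock = 128 -> padded length 1024, pad of 24.
static_assert(padded_length(1000, 128) == 1024, "round-up example");
static_assert(pad_amount(1000, 128) == 24, "pad amount example");
```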
...
@@ -17,6 +17,88 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
i < group_count)
{
auto group_id = i;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_descs[group_id].a_ptr,
gemm_descs[group_id].b_ptr,
gemm_descs[group_id].c_ptr,
p_shared,
gemm_descs[group_id].a_grid_desc_k0_m_k1_,
gemm_descs[group_id].b_grid_desc_k0_n_k1_,
gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
#else
ignore = gemm_descs;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ADataType,
typename BDataType,
typename CDataType,
...
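kernel_grouped_gemm_xdlops_v2r3 above locates its group by testing the 1D block id against each group's [BlockStart_, BlockEnd_) range. A hedged host-side sketch of how such ranges could be assigned (a running prefix sum over per-group grid sizes) is shown below; the GroupDesc struct and the helper are illustrative placeholders, not this library's API.

```cpp
#include <cstdint>
#include <vector>

// Illustrative per-group descriptor holding only the block-id range.
struct GroupDesc
{
    std::int32_t BlockStart_;
    std::int32_t BlockEnd_;
};

// Assign contiguous block-id ranges to groups via a running prefix sum of
// each group's grid size (one block per C-tile of that group's GEMM).
std::vector<GroupDesc> assign_block_ranges(const std::vector<std::int32_t>& grid_sizes)
{
    std::vector<GroupDesc> descs(grid_sizes.size());
    std::int32_t running = 0;
    for(std::size_t g = 0; g < grid_sizes.size(); ++g)
    {
        descs[g].BlockStart_ = running;   // first block id owned by group g
        running += grid_sizes[g];
        descs[g].BlockEnd_ = running;     // one past the last block id of group g
    }
    return descs; // `running` is the total grid size for the single launch
}
```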
...
@@ -66,11 +66,6 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
-/*
- * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
- *
- * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
- */
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
...
...
@@ -67,184 +67,6 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, and C
* matrices for a given batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
* offsets for evenly strided batches, but this can easily be extended to other layouts. The
* returned offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we
* are not subject to the 2 GB limitation.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in the ID of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility to let two workgroups compute
* tiles from two different matrices. Keep in mind that these two matrices can share the same grid
* descriptor (as in BatchedGEMM) or use their own grid descriptors (as in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computation
* of the pointer offset in \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows a customized mapping between a workgroup and the C-tile it
* computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
* fusions) to realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusions).
*
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
i < group_count)
{
auto group_id = i;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_[group_id].a_ptr,
gemm_desc_[group_id].b_ptr,
gemm_desc_[group_id].c_ptr,
p_shared,
gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
#else
ignore = gemm_desc_;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
...
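Finally, a note on the batch-index arithmetic used by kernel_batched_gemm_xdlops_v2r3 (num_blocks_per_batch = grid_size / batch_count, g_idx = block_id / num_blocks_per_batch): the grid is launched with batch_count copies of the per-GEMM grid. The following host-side sketch, with a small worked example, restates that mapping; the local_block_id field is added here only for illustration and is not computed by the kernel shown above.

```cpp
#include <cassert>

// Host-side restatement of the block-to-batch mapping in the batched kernel:
// the grid holds (blocks_per_gemm * batch_count) workgroups, so a workgroup
// recovers its batch index by integer division of its 1D block id.
struct BlockMapping
{
    int batch_idx;      // which GEMM of the batch this block works on
    int local_block_id; // block id within that GEMM's own grid (illustrative)
};

BlockMapping map_block(int block_id, int grid_size, int batch_count)
{
    assert(grid_size % batch_count == 0); // launch assumed an exact multiple
    const int num_blocks_per_batch = grid_size / batch_count;
    return {block_id / num_blocks_per_batch, block_id % num_blocks_per_batch};
}

// Example: grid_size = 240, batch_count = 8 -> 30 blocks per batch, and
// block_id = 95 maps to batch_idx = 3, local_block_id = 5.
```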