Commit e739c577 authored by Jianfeng Yan

refactor DeviceGemmXdlSplitK to support arbitrary K; remove template parameter A/B/CGridDesc from gridwise_gemm_v2r3; start rewriting conv2d_backward
parent 22d63c05
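The heart of the change, before the per-file diffs: split-K no longer requires K to divide evenly. Each batch covers a rounded-up slice KSplitted of K, and the remainder runs through separate "tail" descriptors. A standalone sketch of that arithmetic (plain C++ with illustrative sizes, not the library code):

#include <cassert>
#include <cstdio>
#include <utility>

using index_t = int; // stand-in for ck::index_t in this sketch

index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

// Mirrors DeviceGemmXdlSplitK::GetActualBatchAndKSplitted below: every batch
// except possibly the last covers KSplitted elements of the K dimension; the
// last batch is handled by separate "tail" grid descriptors.
std::pair<index_t, index_t>
get_actual_batch_and_ksplitted(index_t K, index_t KBatch, index_t K1, index_t K0PerBlock)
{
    const index_t K0           = integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
    const index_t KSplitted    = K0 * K1;
    const index_t actual_batch = integer_divide_ceil(K, KSplitted);
    return {actual_batch, KSplitted};
}

int main()
{
    // K = 1000, K1 = 8, K0PerBlock = 4, requested KBatch = 4 (illustrative):
    // KSplitted = 256, actual_batch = 4, tail covers 1000 - 3 * 256 = 232.
    const auto [batch, ksplit] = get_actual_batch_and_ksplitted(1000, 4, 8, 4);
    std::printf("batch=%d ksplit=%d tail=%d\n", batch, ksplit, 1000 - ksplit * (batch - 1));
    assert(batch == 4 && ksplit == 256);
}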
#pragma once
#include "tuple.hpp"
#include "tensor_adaptor.hpp"
#include "multi_index_transform_helper.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

struct BatchedGemmUtil
{
    template <index_t MPerBlock, index_t NPerBlock>
    static constexpr auto
    MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01 = 1, index_t N01 = 1)
    {
        constexpr auto M1 = MPerBlock;
        constexpr auto N1 = NPerBlock;

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_insert_transform(batch_count),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto globalblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return globalblockid_to_m0_n0_block_cluster_adaptor;
    }

    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                       index_t BatchStrideB,
                                       index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideC_);
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideC_;
    };
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
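A minimal host-side usage sketch for the helper above (illustrative tile sizes and strides; assumes this header and its ck dependencies are on the include path):

#include "batched_gemm_util.hpp"

using ck::index_t;
using ck::tensor_operation::device::BatchedGemmUtil;

void example()
{
    // 256 x 512 C matrix in 128 x 128 tiles across 4 split-K batches: the
    // adaptor maps flat block ids 0..31 onto (batch, m0, n0) tile indices.
    const auto block_2_ctile_map =
        BatchedGemmUtil::MakeBlock2CTileMap<128, 128>(4 /*batch_count*/, 256 /*M*/, 512 /*N*/);
    (void)block_2_ctile_map;

    // Per-batch pointer offsets for strided-batch A/B/C buffers (strides are
    // made-up element counts).
    BatchedGemmUtil::ComputePtrOffsetOfStridedBatch offsets{1024, 2048, 0};
    const auto a_offset = offsets.GetAPtrOffset(3); // 3 * 1024
    (void)a_offset;
}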
@@ -3,6 +3,7 @@
 #include <iostream>
 #include <sstream>
+#include "batched_gemm_util.hpp"
 #include "device.hpp"
 #include "device_base.hpp"
 #include "device_conv_backward_weight.hpp"
@@ -12,6 +13,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdlops_v2r4r2.hpp"
+#include "batched_gemm_util.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -25,33 +27,34 @@ template <typename InDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t BlockSize,
-          ck::index_t MPerBlock,
-          ck::index_t NPerBlock,
-          ck::index_t K0PerBlock,
-          ck::index_t K1,
+          index_t NumGemmKPrefetchStage,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t K0PerBlock,
+          index_t AK1,
           ck::index_t MPerXdl,
           ck::index_t NPerXdl,
           ck::index_t MXdlPerWave,
           ck::index_t NXdlPerWave,
-          typename ABlockTransferThreadClusterLengths_K0_M_K1,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
-          ck::index_t ABlockTransferSrcVectorDim,
-          ck::index_t ABlockTransferSrcScalarPerVector,
-          ck::index_t ABlockTransferDstScalarPerVector_K1,
-          bool ABlockLdsAddExtraM,
-          typename BBlockTransferThreadClusterLengths_K0_N_K1,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_AK1,
+          bool ABlockLdsExtraM,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
-          ck::index_t BBlockTransferSrcVectorDim,
-          ck::index_t BBlockTransferSrcScalarPerVector,
-          ck::index_t BBlockTransferDstScalarPerVector_K1,
-          bool BBlockLdsAddExtraN,
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_BK1,
+          bool BBlockLdsExtraN,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
-          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
+          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
 struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     : public DeviceConvBwdWeight<InElementwiseOperation,
                                  WeiElementwiseOperation,
@@ -92,7 +95,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
-                       ck::index_t batch_k)
+                       ck::index_t k_batch)
     {
         using namespace ck;
@@ -117,35 +120,40 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         const index_t InRightPadH = input_right_pads[0];
         const index_t InRightPadW = input_right_pads[1];

         const index_t GemmKTotal = N * Ho * Wo;
         const index_t GemmM      = K;
         const index_t GemmN      = C * X * Y;

-        const index_t GemmKBatch = batch_k;
-        const index_t GemmK0 =
-            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
-            K0PerBlock;
-        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
+        const index_t GemmAK0 =
+            math::integer_divide_ceil(GemmKTotal, GemmAK1Number * K0PerBlock * k_batch) *
+            K0PerBlock;
+        const index_t GemmBK0 =
+            math::integer_divide_ceil(GemmKTotal, GemmBK1Number * K0PerBlock * k_batch) *
+            K0PerBlock;
+        const index_t GemmAKPerBatch = GemmAK0 * GemmAK1Number;
+        const index_t GemmBKPerBatch = GemmBK0 * GemmBK1Number;
+        const index_t GemmAKPad = GemmKBatch * GemmAK0 * GemmAK1Number;
+        const index_t GemmBKPad = GemmKBatch * GemmBK0 * GemmBK1Number;

-        const auto out_gemmktotal_gemmm_grid_desc =
+        const auto out_gemmk_gemmm_grid_desc =
             make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
         const auto in_n_hi_wi_c_grid_desc =
             make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

         // A: output tensor
         const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-            out_gemmktotal_gemmm_grid_desc,
-            make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
+            out_gemmk_gemmm_grid_desc,
+            make_tuple(make_right_pad_transform(GemmKTotal, GemmAKPad - GemmKTotal),
                        make_pass_through_transform(GemmM)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+        const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
             out_gemmkpad_gemmm_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+            make_tuple(make_unmerge_transform(make_tuple(GemmAK0, GemmAK1Number)),
                        make_pass_through_transform(GemmM)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

         // B: input tensor
         const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
@@ -176,24 +184,24 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
             in_gemmktotal_gemmn_grid_desc,
-            make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
+            make_tuple(make_right_pad_transform(GemmKTotal, GemmBKPad - GemmKTotal),
                        make_pass_through_transform(GemmN)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+        const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
             in_gemmkpad_gemmn_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+            make_tuple(make_unmerge_transform(make_tuple(GemmBK0, GemmBK1Number)),
                        make_pass_through_transform(GemmN)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

         // C: weight tensor
         const auto wei_gemmm_gemmn_grid_desc =
             make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));

-        return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
-                          in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
+        return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
+                          in_gemmk0_gemmn_gemmk1_grid_desc,
                           wei_gemmm_gemmn_grid_desc);
     }
@@ -205,93 +213,97 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
     using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

     // GridwiseGemm
-    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
-        BlockSize,
+    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
+        GemmAccDataType,
+        CShuffleDataType,
         CDataType,
-        InMemoryDataOperationEnum::Set,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
+        InMemoryDataOperationEnum::Set,
+        AGridDesc_AK0_M_AK1,
+        BGridDesc_BK0_N_BK1,
+        CGridDesc_M_N,
+        NumGemmKPrefetchStage,
+        BlockSize,
         MPerBlock,
         NPerBlock,
-        K0PerBlock,
-        MPerXdl,
-        NPerXdl,
-        K1,
+        KPerBlock,
+        AK1,
+        BK1,
+        MPerXDL,
+        NPerXDL,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
         ABlockTransferSrcVectorDim,
         ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
-        false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM,
-        BBlockTransferThreadClusterLengths_K0_N_K1,
+        ABlockTransferDstScalarPerVector_AK1,
+        false,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
         BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
-        false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN,
+        BBlockTransferDstScalarPerVector_BK1,
+        false,
+        BBlockLdsExtraN,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
-        CBlockTransferScalarPerVector_NWaveNPerXdl,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
+        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CShuffleBlockTransferScalarPerVector_NPerBlock>;

-    using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
-        BlockSize,
+    using GridwiseGemmAtomicAdd = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
         ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
+        GemmAccDataType,
+        CShuffleDataType,
         CDataType,
-        InMemoryDataOperationEnum::AtomicAdd,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
+        InMemoryDataOperationEnum::AtomicAdd,
+        AGridDesc_AK0_M_AK1,
+        BGridDesc_BK0_N_BK1,
+        CGridDesc_M_N,
+        NumGemmKPrefetchStage,
+        BlockSize,
         MPerBlock,
         NPerBlock,
-        K0PerBlock,
-        MPerXdl,
-        NPerXdl,
-        K1,
+        KPerBlock,
+        AK1,
+        BK1,
+        MPerXDL,
+        NPerXDL,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
         ABlockTransferSrcVectorDim,
         ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
-        false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM,
-        BBlockTransferThreadClusterLengths_K0_N_K1,
+        ABlockTransferDstScalarPerVector_AK1,
+        false,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
         BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
-        false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN,
+        BBlockTransferDstScalarPerVector_BK1,
+        false,
+        BBlockLdsExtraN,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
-        CBlockTransferScalarPerVector_NWaveNPerXdl,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
+        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CShuffleBlockTransferScalarPerVector_NPerBlock>;

-    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-        decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
     using Block2CTileMap =
-        decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
+        decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));

     // Argument
     struct Argument : public BaseArgument
     {
         Argument(const InDataType* p_in_grid,
@@ -316,8 +328,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             : p_a_grid_{p_out_grid},
               p_b_grid_{p_in_grid},
               p_c_grid_{p_wei_grid},
-              a_grid_desc_kbatch_k0_m_k1_{},
-              b_grid_desc_kbatch_k0_n_k1_{},
+              a_grid_desc_k0_m_k1_{},
+              b_grid_desc_k0_n_k1_{},
              c_grid_desc_m_n_{},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{},
@@ -349,31 +361,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                 input_right_pads,
                 k_batch_);

-            a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
-            b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
+            a_grid_desc_k0_m_k1_ = descs[I0];
+            b_grid_desc_k0_n_k1_ = descs[I1];
             c_grid_desc_m_n_ = descs[I2];

-            if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
-                                           b_grid_desc_kbatch_k0_n_k1_,
-                                           c_grid_desc_m_n_,
-                                           M01_,
-                                           N01_))
+            if(GridwiseGemm::CheckValidity(
+                   a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
             {
                 c_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
+                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        c_grid_desc_m_n_);

-                block_2_ctile_map_ =
-                    GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
+                block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(
+                    k_batch_,
+                    c_grid_desc_m_n_.GetLength(I0),
+                    c_grid_desc_m_n_.GetLength(I1),
+                    M01_,
+                    N01_);
             }
         }

         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
+        index_t k_batch_;
+        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
+        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
-        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
+        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock_;
+        BatchedGemmUtil::ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
         Block2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;
@@ -399,17 +416,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         void ShowInfo(const Argument& arg)
         {
-            std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
-
-            std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
+            std::cout << "k_batch = " << arg.BatchCount_ << "\n";
+            std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
+                      << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                      << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+
+            std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
+                      << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                      << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
@@ -418,8 +432,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         float Run(const Argument& arg, int nrepeat = 1)
         {
             ShowInfo(arg);

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
-                                            arg.b_grid_desc_kbatch_k0_n_k1_,
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                            arg.b_grid_desc_k0_n_k1_,
                                             arg.c_grid_desc_m_n_,
                                             arg.M01_,
                                             arg.N01_))
@@ -427,10 +441,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                 throw std::runtime_error(
                     "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
             }

-            const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
-            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
+            const auto k_batch = arg.k_batch_;
+            const index_t grid_size =
+                GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * k_batch;

-            const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
+            const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);

             const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
@@ -448,8 +463,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.a_grid_desc_kbatch_k0_m_k1_,
-                    arg.b_grid_desc_kbatch_k0_n_k1_,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.a_element_op_,
                     arg.b_element_op_,
@@ -457,7 +472,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                     arg.block_2_ctile_map_);
             }

-            if(kbatch > 1 || nrepeat <= 0)
+            if(k_batch > 1 || nrepeat <= 0)
             {
                 hipGetErrorString(hipMemset(
                     arg.p_c_grid_,
@@ -472,8 +487,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.a_grid_desc_kbatch_k0_m_k1_,
-                    arg.b_grid_desc_kbatch_k0_n_k1_,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                     arg.a_element_op_,
                     arg.b_element_op_,
@@ -484,36 +499,38 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             if(has_main_k0_block_loop)
             {
-                if(kbatch == 1)
+                if(k_batch == 1)
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
                         GridwiseGemm,
                         ADataType, // TODO: distiguish A/B datatype
                         CDataType,
-                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        OutElementwiseOperation,
-                        InElementwiseOperation,
-                        WeiElementwiseOperation,
-                        remove_reference_t<DeviceOp::Block2CTileMap>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1,
+                        DeviceOp::BGridDesc_BK0_N_BK1,
+                        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch,
+                        Block2CTileMap,
                         true>;

                     Run(kernel);
                 }
                 else
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
                         GridwiseGemmAtomicAdd,
                         ADataType, // TODO: distiguish A/B datatype
                         CDataType,
-                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        OutElementwiseOperation,
-                        InElementwiseOperation,
-                        WeiElementwiseOperation,
-                        remove_reference_t<DeviceOp::Block2CTileMap>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1,
+                        DeviceOp::BGridDesc_BK0_N_BK1,
+                        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch,
+                        Block2CTileMap,
                         true>;

                     Run(kernel);
@@ -521,36 +538,38 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             }
             else
             {
-                if(kbatch == 1)
+                if(k_batch == 1)
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
                         GridwiseGemm,
                         ADataType, // TODO: distiguish A/B datatype
                         CDataType,
-                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        OutElementwiseOperation,
-                        InElementwiseOperation,
-                        WeiElementwiseOperation,
-                        remove_reference_t<DeviceOp::Block2CTileMap>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1,
+                        DeviceOp::BGridDesc_BK0_N_BK1,
+                        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch,
+                        Block2CTileMap,
                         false>;

                     Run(kernel);
                 }
                 else
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
                         GridwiseGemmAtomicAdd,
                         ADataType, // TODO: distiguish A/B datatype
                         CDataType,
-                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        OutElementwiseOperation,
-                        InElementwiseOperation,
-                        WeiElementwiseOperation,
-                        remove_reference_t<DeviceOp::Block2CTileMap>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1,
+                        DeviceOp::BGridDesc_BK0_N_BK1,
+                        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch,
+                        Block2CTileMap,
                         false>;

                     Run(kernel);
@@ -583,14 +602,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             }

             // vector store C matrix into global memory
-            if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
+            if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
             {
                 return false;
             }

             // Gridwise GEMM size
-            return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
-                                               arg.b_grid_desc_kbatch_k0_n_k1_,
+            return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                               arg.b_grid_desc_k0_n_k1_,
                                                arg.c_grid_desc_m_n_,
                                                arg.M01_,
                                                arg.N01_);
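Note on the dispatch above: for k_batch == 1 the plain-store GridwiseGemm is used, while k_batch > 1 selects GridwiseGemmAtomicAdd and first zeroes C with hipMemset so each K-slice can accumulate its partial product. A scalar model of that accumulation (an illustration, not the device code):

#include <algorithm>
#include <vector>

void splitk_model(const std::vector<float>& a, // M x K, row-major
                  const std::vector<float>& b, // K x N, row-major
                  std::vector<float>& c,       // M x N, must start zeroed
                  int M, int N, int K, int k_batch)
{
    const int k_per_batch = (K + k_batch - 1) / k_batch;
    for(int g = 0; g < k_batch; ++g) // one grid "batch" per K-slice
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = g * k_per_batch; k < std::min(K, (g + 1) * k_per_batch); ++k)
                    c[m * N + n] += a[m * K + k] * b[k * N + n]; // += models AtomicAdd
}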
 #ifndef DEVICE_GEMM_XDL_SPLITK_HPP
 #define DEVICE_GEMM_XDL_SPLITK_HPP

+#include <cstdio>
 #include <iostream>
 #include <sstream>
+#include <utility>
 #include "device.hpp"
 #include "device_base.hpp"
 #include "device_gemm.hpp"
@@ -22,13 +24,16 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
+          typename AGridDesc_K0_M_K1_Tail,
+          typename BGridDesc_K0_N_K1_Tail,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename ComputePtrOffsetOfBatch,
           typename Block2CTileMap,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          bool TailHasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -40,6 +45,8 @@ __global__ void
             const index_t batch_count,
             const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+            const AGridDesc_K0_M_K1_Tail a_grid_desc_k0_m_k1_tail,
+            const BGridDesc_K0_N_K1_Tail b_grid_desc_k0_n_k1_tail,
             const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
@@ -61,17 +68,34 @@ __global__ void
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                  p_b_grid + b_batch_offset,
-                                                  p_c_grid + c_batch_offset,
-                                                  p_shared,
-                                                  a_grid_desc_k0_m_k1,
-                                                  b_grid_desc_k0_n_k1,
-                                                  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  c_element_op,
-                                                  block_2_ctile_map);
+    if(g_idx < batch_count - 1)
+    {
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                      p_b_grid + b_batch_offset,
+                                                      p_c_grid + c_batch_offset,
+                                                      p_shared,
+                                                      a_grid_desc_k0_m_k1,
+                                                      b_grid_desc_k0_n_k1,
+                                                      c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      c_element_op,
+                                                      block_2_ctile_map);
+    }
+    else
+    {
+        GridwiseGemm::template Run<TailHasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                          p_b_grid + b_batch_offset,
+                                                          p_c_grid + c_batch_offset,
+                                                          p_shared,
+                                                          a_grid_desc_k0_m_k1_tail,
+                                                          b_grid_desc_k0_n_k1_tail,
+                                                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                          a_element_op,
+                                                          b_element_op,
+                                                          c_element_op,
+                                                          block_2_ctile_map);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -79,6 +103,8 @@
     ignore = batch_count;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
+    ignore = a_grid_desc_k0_m_k1_tail;
+    ignore = b_grid_desc_k0_n_k1_tail;
     ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
     ignore = a_element_op;
     ignore = b_element_op;
@@ -133,16 +159,21 @@ struct DeviceGemmXdlSplitK
     static constexpr auto K1Number = Number<K1>{};

-    static auto GetKPad(index_t K, index_t KBatch)
+    // static constexpr index_t Getk
+    static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
     {
         const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
-        const index_t KPad = KBatch * K0 * K1;
-        return KPad;
+        const index_t KSplitted = K0 * K1;
+        const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
+        // return std::make_pair<index_t, index_t>(actual_batch, KSplitted);
+        return std::make_pair(actual_batch, KSplitted);
     }

     static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {
-        assert(K % K1 == 0);
+        assert(K % (K1 * K0PerBlock) == 0);

         const index_t K0 = K / K1;
@@ -181,7 +212,7 @@ struct DeviceGemmXdlSplitK
     static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
     {
-        assert(K % K1 == 0);
+        assert(K % (K1 * K0PerBlock) == 0);

         const index_t K0 = K / K1;
@@ -218,6 +249,95 @@
         }
     }
+    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
+    {
+        const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0   = KPad / K1;
+
+        const auto a_grid_desc_m_k = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
+        {
+            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+
+            return transform_tensor_descriptor(
+                a_grid_desc_m_kpad,
+                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                           make_right_pad_transform(M, PadM)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+        else
+        {
+            return transform_tensor_descriptor(
+                a_grid_desc_m_kpad,
+                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                           make_pass_through_transform(M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+    }
+
+    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
+    {
+        const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0   = KPad / K1;
+
+        const auto b_grid_desc_k_n = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
+            }
+        }();
+
+        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+            b_grid_desc_k_n,
+            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
+        {
+            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+
+            return transform_tensor_descriptor(
+                b_grid_desc_kpad_n,
+                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                           make_right_pad_transform(N, PadN)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+        else
+        {
+            return transform_tensor_descriptor(
+                b_grid_desc_kpad_n,
+                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+    }
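A worked check of the tail sizing introduced above, with illustrative numbers (K = 1000, K1 = 8, K0PerBlock = 4, requested KBatch = 4):

#include <cassert>

int main()
{
    const int K = 1000, K1 = 8, K0PerBlock = 4, KBatch = 4;
    const int unit      = K1 * K0PerBlock;                                    // 32
    const int KSplitted = ((K + unit * KBatch - 1) / (unit * KBatch)) * unit; // 256
    const int batch     = (K + KSplitted - 1) / KSplitted;                    // 4
    const int KTail     = K - KSplitted * (batch - 1);                        // 232
    const int KPad      = ((KTail + unit - 1) / unit) * unit;                 // 256, right-padded
    assert(KSplitted == 256 && batch == 4 && KTail == 232 && KPad == 256);
}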
     static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
     {
         const auto c_grid_desc_m_n = [&]() {
@@ -253,9 +373,11 @@ struct DeviceGemmXdlSplitK
         }
     }

     using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
     using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
-    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
+    using AGridDesc_K0_M_K1_Tail = decltype(MakeAGridDescriptor_K0_M_K1_Tail(1, 1, 1));
+    using BGridDesc_K0_N_K1_Tail = decltype(MakeBGridDescriptor_K0_N_K1_Tail(1, 1, 1));
+    using CGridDesc_M_N          = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     static constexpr auto MakeBlock2CTileMap(index_t batch_count,
                                              const CGridDesc_M_N& c_grid_desc_m_n,
@@ -330,9 +452,9 @@ struct DeviceGemmXdlSplitK
         AccDataType,
         CDataType,
         InMemoryDataOperationEnum::AtomicAdd,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
+        // AGridDesc_K0_M_K1,
+        // BGridDesc_K0_N_K1,
+        // CGridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
@@ -390,27 +512,44 @@ struct DeviceGemmXdlSplitK
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
               BatchCount_(k_batch),
+              has_tail_(false),
               compute_ptr_offset_of_batch_{0, 0},
+              block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
         {
-            const auto KPad = GetKPad(K, k_batch);
-            assert(KPad % k_batch == 0);
-            const auto KSplitted = KPad / k_batch;
+            const auto actual_batch_and_ksplitted = GetActualBatchAndKSplitted(K, k_batch);
+            BatchCount_          = actual_batch_and_ksplitted.first;
+            const auto KSplitted = actual_batch_and_ksplitted.second;

             a_grid_desc_k0_m_k1_ =
                 DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1(M, KSplitted, StrideA);
             b_grid_desc_k0_n_k1_ =
                 DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1(KSplitted, N, StrideB);
             c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC);

-            if(GridwiseGemm::CheckValidity(
-                   a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
+            bool is_valid = GridwiseGemm::CheckValidity(
+                a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_);
+
+            if(K != KSplitted * BatchCount_)
+            {
+                has_tail_ = true;
+
+                const auto KTail = K - KSplitted * (BatchCount_ - 1);
+                a_grid_desc_k0_m_k1_tail_ =
+                    DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1_Tail(M, KTail, StrideA);
+                b_grid_desc_k0_n_k1_tail_ =
+                    DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1_Tail(KTail, N, StrideB);
+
+                is_valid &= GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_tail_,
+                                                        b_grid_desc_k0_n_k1_tail_,
+                                                        c_grid_desc_m_n_,
+                                                        M01_,
+                                                        N01_);
+            }
+
+            if(is_valid)
             {
                 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                     GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
@@ -441,7 +580,7 @@ struct DeviceGemmXdlSplitK
                 compute_ptr_offset_of_batch_ =
                     ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};

-                block_2_ctile_map_ = MakeBlock2CTileMap(k_batch, c_grid_desc_m_n_, M01, N01);
+                block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount_, c_grid_desc_m_n_, M01, N01);
             }
         }
@@ -450,8 +589,11 @@ struct DeviceGemmXdlSplitK
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
         index_t BatchCount_;
+        bool has_tail_;
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
+        AGridDesc_K0_M_K1_Tail a_grid_desc_k0_m_k1_tail_;
+        BGridDesc_K0_N_K1_Tail b_grid_desc_k0_n_k1_tail_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
@@ -482,13 +624,36 @@ struct DeviceGemmXdlSplitK
                 std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                           << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+
+                if(arg.has_tail_)
+                {
+                    std::cout << "arg.a_grid_desc_k0_m_k1_tail_{"
+                              << arg.a_grid_desc_k0_m_k1_tail_.GetLength(I0) << ", "
+                              << arg.a_grid_desc_k0_m_k1_tail_.GetLength(I1) << ", "
+                              << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+
+                    std::cout << "arg.b_grid_desc_k0_n_k1_tail_{"
+                              << arg.b_grid_desc_k0_n_k1_tail_.GetLength(I0) << ", "
+                              << arg.b_grid_desc_k0_n_k1_tail_.GetLength(I1) << ", "
+                              << arg.b_grid_desc_k0_n_k1_tail_.GetLength(I2) << "}" << std::endl;
+                }
             }

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                            arg.b_grid_desc_k0_n_k1_,
-                                            arg.c_grid_desc_m_n_,
-                                            arg.M01_,
-                                            arg.N01_))
+            bool is_valid = GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                                        arg.b_grid_desc_k0_n_k1_,
+                                                        arg.c_grid_desc_m_n_,
+                                                        arg.M01_,
+                                                        arg.N01_);
+
+            if(arg.has_tail_)
+            {
+                is_valid &= GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_tail_,
+                                                        arg.b_grid_desc_k0_n_k1_tail_,
+                                                        arg.c_grid_desc_m_n_,
+                                                        arg.M01_,
+                                                        arg.N01_);
+            }
+
+            if(!is_valid)
             {
                 throw std::runtime_error(
                     "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
@@ -497,79 +662,246 @@ struct DeviceGemmXdlSplitK
             const index_t grid_size =
                 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;

-            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
             float ave_time = 0;

-            if(has_main_k0_block_loop)
+            if(arg.has_tail_)
             {
-                const auto kernel = kernel_batched_gemm_xdlops_v2r3<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    CDataType,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CElementwiseOperation,
-                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<Block2CTileMap>,
-                    true>;
-
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
+                const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+                const auto K0_tail = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
+                const bool tail_has_main_k0_block_loop =
+                    GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
+
+                if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
+                {
+                    const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distiguish A/B datatype
+                        CDataType,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
+                        remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        ComputePtrOffsetOfStridedBatch,
+                        remove_reference_t<Block2CTileMap>,
+                        true,
+                        true>;
+
+                    ave_time = launch_and_time_kernel(kernel,
+                                                      nrepeat,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.BatchCount_,
+                                                      arg.a_grid_desc_k0_m_k1_,
+                                                      arg.b_grid_desc_k0_n_k1_,
+                                                      arg.a_grid_desc_k0_m_k1_tail_,
+                                                      arg.b_grid_desc_k0_n_k1_tail_,
+                                                      arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.compute_ptr_offset_of_batch_,
+                                                      arg.block_2_ctile_map_);
+                }
+                else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
+                {
+                    const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distiguish A/B datatype
+                        CDataType,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
+                        remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        ComputePtrOffsetOfStridedBatch,
+                        remove_reference_t<Block2CTileMap>,
+                        true,
+                        false>;
+
+                    ave_time = launch_and_time_kernel(kernel,
+                                                      nrepeat,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.BatchCount_,
+                                                      arg.a_grid_desc_k0_m_k1_,
+                                                      arg.b_grid_desc_k0_n_k1_,
+                                                      arg.a_grid_desc_k0_m_k1_tail_,
+                                                      arg.b_grid_desc_k0_n_k1_tail_,
+                                                      arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.compute_ptr_offset_of_batch_,
+                                                      arg.block_2_ctile_map_);
+                }
+                else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
+                {
+                    const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distiguish A/B datatype
+                        CDataType,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
+                        remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        ComputePtrOffsetOfStridedBatch,
+                        remove_reference_t<Block2CTileMap>,
+                        false,
+                        true>;
+
+                    ave_time = launch_and_time_kernel(kernel,
+                                                      nrepeat,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.BatchCount_,
+                                                      arg.a_grid_desc_k0_m_k1_,
+                                                      arg.b_grid_desc_k0_n_k1_,
+                                                      arg.a_grid_desc_k0_m_k1_tail_,
+                                                      arg.b_grid_desc_k0_n_k1_tail_,
+                                                      arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.compute_ptr_offset_of_batch_,
+                                                      arg.block_2_ctile_map_);
+                }
+                else
+                {
+                    const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distiguish A/B datatype
+                        CDataType,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
+                        remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        ComputePtrOffsetOfStridedBatch,
+                        remove_reference_t<Block2CTileMap>,
+                        false,
+                        false>;
+
+                    ave_time = launch_and_time_kernel(kernel,
+                                                      nrepeat,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.BatchCount_,
+                                                      arg.a_grid_desc_k0_m_k1_,
+                                                      arg.b_grid_desc_k0_n_k1_,
+                                                      arg.a_grid_desc_k0_m_k1_tail_,
+                                                      arg.b_grid_desc_k0_n_k1_tail_,
+                                                      arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.compute_ptr_offset_of_batch_,
+                                                      arg.block_2_ctile_map_);
+                }
             }
             else
             {
-                const auto kernel = kernel_batched_gemm_xdlops_v2r3<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    CDataType,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CElementwiseOperation,
-                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<Block2CTileMap>,
-                    false>;
-
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
+                const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+
+                if(has_main_k0_block_loop)
+                {
+                    const auto kernel = ck::kernel_batched_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distiguish A/B datatype
+                        CDataType,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                        remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        ComputePtrOffsetOfStridedBatch,
+                        remove_reference_t<Block2CTileMap>,
+                        true>;
+
+                    ave_time = launch_and_time_kernel(kernel,
+                                                      nrepeat,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.BatchCount_,
+                                                      arg.a_grid_desc_k0_m_k1_,
+                                                      arg.b_grid_desc_k0_n_k1_,
+                                                      arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.compute_ptr_offset_of_batch_,
+                                                      arg.block_2_ctile_map_);
+                }
+                else
+                {
+                    const auto kernel = ck::kernel_batched_gemm_xdlops_v2r3<
+                        GridwiseGemm,
+                        ADataType, // TODO: distiguish A/B datatype
+                        CDataType,
+                        remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                        remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
+                        ComputePtrOffsetOfStridedBatch,
+                        remove_reference_t<Block2CTileMap>,
+                        false>;
+
+                    ave_time = launch_and_time_kernel(kernel,
+                                                      nrepeat,
+                                                      dim3(grid_size),
+                                                      dim3(BlockSize),
+                                                      0,
+                                                      arg.p_a_grid_,
+                                                      arg.p_b_grid_,
+                                                      arg.p_c_grid_,
+                                                      arg.BatchCount_,
+                                                      arg.a_grid_desc_k0_m_k1_,
+                                                      arg.b_grid_desc_k0_n_k1_,
+                                                      arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                      arg.a_element_op_,
+                                                      arg.b_element_op_,
+                                                      arg.c_element_op_,
+                                                      arg.compute_ptr_offset_of_batch_,
+                                                      arg.block_2_ctile_map_);
+                }
             }

             return ave_time;
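The four launch branches above exist because HasMainKBlockLoop and TailHasMainKBlockLoop are compile-time template parameters, so each runtime combination needs its own kernel instantiation. A generic sketch of that runtime-to-compile-time dispatch (hypothetical helper, not library code):

#include <type_traits>
#include <utility>

// Lifts two runtime bools into the compile-time values a kernel
// instantiation needs, mirroring the four branches above.
template <typename F>
void dispatch_main_loop(bool main_loop, bool tail_main_loop, F&& launch)
{
    if(main_loop && tail_main_loop)
        launch(std::true_type{}, std::true_type{});
    else if(main_loop && !tail_main_loop)
        launch(std::true_type{}, std::false_type{});
    else if(!main_loop && tail_main_loop)
        launch(std::false_type{}, std::true_type{});
    else
        launch(std::false_type{}, std::false_type{});
}

// Usage sketch, where run_kernel is a placeholder for the real launch:
//   dispatch_main_loop(a, b, [](auto Main, auto Tail) {
//       run_kernel<decltype(Main)::value, decltype(Tail)::value>();
//   });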
@@ -13,6 +13,7 @@
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdl_cshuffle_v1.hpp"
 #include "gemm_specialization.hpp"
+#include "batched_gemm_util.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -343,43 +344,6 @@ struct DeviceGemmXdlSplitKCShuffle
     using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
     using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

-    static constexpr auto MakeBlock2CTileMap(index_t batch_count,
-                                             const CGridDesc_M_N& c_grid_desc_m_n,
-                                             index_t M01,
-                                             index_t N01)
-    {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
-        const auto M00 = M0 / M01;
-        const auto N00 = N0 / N01;
-
-        const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_insert_transform(batch_count),
-                           make_unmerge_transform(make_tuple(M00, M01)),
-                           make_unmerge_transform(make_tuple(N00, N01))),
-                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
-
-        const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
-                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                make_tuple(Sequence<0>{}));
-
-        const auto globalblockid_to_m0_n0_block_cluster_adaptor =
-            chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
-
-        return globalblockid_to_m0_n0_block_cluster_adaptor;
-    }
-
     struct ComputePtrOffsetOfStridedBatch
     {
@@ -455,7 +419,7 @@ struct DeviceGemmXdlSplitKCShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock>;

-    using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
+    using Block2CTileMap =
+        decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));

     struct Argument : public BaseArgument
     {
@@ -529,7 +493,7 @@ struct DeviceGemmXdlSplitKCShuffle
                 compute_ptr_offset_of_batch_ =
                     ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};

-                block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount_, c_grid_desc_m_n_, 1, 1);
+                block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(
+                    BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
             }
         }
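The v2r3 changes below move the grid-descriptor types off the class template and onto each member-function template, so a single GridwiseGemm instantiation can be called with both the regular and the tail descriptors. A stripped-down illustration of the pattern (hypothetical names):

template <int BlockSize>
struct GridwiseGemmSketch
{
    // Descriptor type is now a function template parameter, deduced per call,
    // instead of a class template parameter fixed at instantiation.
    template <typename CGridDesc_M_N>
    static constexpr int CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return (c_grid_desc_m_n.GetLength(0) / 128) * (c_grid_desc_m_n.GetLength(1) / 128);
    }
};

struct Desc2D
{
    int len[2];
    constexpr int GetLength(int i) const { return len[i]; }
};

// One instantiation now accepts any descriptor shape with the right interface.
static_assert(GridwiseGemmSketch<256>::CalculateGridSize(Desc2D{{256, 512}}) == 8);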
@@ -250,9 +250,6 @@ template <index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AGridDesc_K0_M_K1,
-          typename BGridDesc_K0_N_K1,
-          typename CGridDesc_M_N,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,

@@ -361,6 +358,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,

@@ -420,9 +418,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return true;
     }

+    template <typename CGridDesc_M_N>
     __host__ __device__ static constexpr index_t
     CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
     {
+        static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);

@@ -439,9 +439,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return has_main_k0_block_loop;
     }

+    template <typename CGridDesc_M_N>
     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
     {
+        static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
         constexpr auto max_lds_align = K1;

         // A matrix in LDS memory, dst of blockwise copy

@@ -491,11 +494,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
-    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    MakeDefaultBlock2CTileMap(index_t M, index_t N, index_t M01, index_t N01)
     {
-        const auto M = c_grid_desc_m_n.GetLength(I0);
-        const auto N = c_grid_desc_m_n.GetLength(I1);
-
         constexpr auto M1 = Number<MPerBlock>{};
         constexpr auto N1 = Number<NPerBlock>{};

@@ -525,11 +525,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return cblockid_to_m0_n0_block_cluster_adaptor;
     }

-    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
-        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
-    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
+    // using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
+    //     decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
+    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(1, 1, 1, 1));

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainK0BlockLoop,
+              typename AGridDesc_K0_M_K1,
+              typename BGridDesc_K0_N_K1,
+              typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
+              typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
...
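The practical effect of moving the descriptor types off the class template is that a single GridwiseGemm instantiation can serve any descriptor shape; each method deduces its descriptor types per call. A minimal stand-in sketch of that pattern (the types below are toy simplifications, not the CK descriptors):

    #include <cstdio>

    struct Desc2D // toy stand-in for a CGridDesc_M_N
    {
        int len[2];
        constexpr int GetLength(int i) const { return len[i]; }
    };

    template <int MPerBlock, int NPerBlock>
    struct GridwiseSketch
    {
        // descriptor type deduced per call, mirroring the refactor above
        template <typename CGridDesc_M_N>
        static constexpr int CalculateGridSize(const CGridDesc_M_N& c)
        {
            return (c.GetLength(0) / MPerBlock) * (c.GetLength(1) / NPerBlock);
        }
    };

    int main()
    {
        constexpr Desc2D c{{512, 256}};
        // (512 / 128) * (256 / 128) = 4 * 2 = 8 workgroups
        std::printf("grid size = %d\n", GridwiseSketch<128, 128>::CalculateGridSize(c));
        return 0;
    }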
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R4R2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
#endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
}
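// Illustrative dispatch sketch (hypothetical host-side names; the real
// dispatch lives in the owning device op, not in this header):
//
//   const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);
//   const bool has_main_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
//   // select kernel_gemm_xdlops_v2r4r2<..., HasMainKBlockLoop = has_main_loop>
//   // and launch it on dim3(grid_size) workgroups of dim3(BlockSize) threads
//
// One workgroup handles one (k-batch, M tile, N tile) triple; the static
// __shared__ block above is sized by GetSharedMemoryNumberOfByte().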
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CMNGridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
c_block_size * sizeof(FloatC));
}
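// Worked example (hypothetical tuning parameters): K0PerBlock = 4,
// MPerBlock = NPerBlock = 128, K1 = 8, FloatAB = half, no LDS padding.
// The A and B tiles each hold 4 * 128 * 8 = 4096 elements, so the GEMM
// phase needs (4096 + 4096) * 2 B = 16 KiB of LDS; the C shuffle phase
// reuses the same allocation, hence the max() of the two phase sizes.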
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
index_t M01,
index_t N01)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 needs to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
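// Example of the divisibility requirements (hypothetical sizes): with
// MPerBlock = NPerBlock = 128 and K0PerBlock = 4, a problem with M = 512,
// N = 256, K0 = 32 passes all of the checks above, while M = 500 fails
// M % MPerBlock == 0; shapes that do not divide evenly have to be handled
// upstream (e.g. by padding) before reaching this kernel.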
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
return grid_size;
}
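// e.g. (hypothetical) M = 3840, N = 4096, MPerBlock = 256, NPerBlock = 128,
// KBatch = 4 gives (3840 / 256) * (4096 / 128) * 4 = 15 * 32 * 4 = 1920
// workgroups.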
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
return has_main_k0_block_loop;
}
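// e.g. K0 = 32 with K0PerBlock = 4 -> true (the do/while body in Run()
// executes 32 / 4 - 1 = 7 times after the preload); K0 = 4 -> false, and
// only the preload plus the tail GEMM run.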
__host__ __device__ static constexpr auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
return transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
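// e.g. (hypothetical) M = 512, MPerBlock = 128 -> MBlock = 4, and a flat C
// row index m = 300 unmerges into (mblock, mperblock) = (2, 44), since
// 300 = 2 * 128 + 44; the N dimension splits the same way.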
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(KBatch),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
}
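// Layout consequence (hypothetical M00 = 4, N00 = 2, M01 = N01 = 1,
// KBatch = 2): consecutive block ids sweep n0 fastest, then m0, then the
// k-batch, so ids 0..7 cover every (m0, n0) tile for k-batch 0 and ids
// 8..15 revisit the same tiles for k-batch 1. Because every k-batch writes
// the same C tiles, split-K dispatches presumably pair this kernel with an
// accumulating CGlobalMemoryDataOperation such as AtomicAdd.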
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
return make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
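// e.g. (hypothetical) MPerBlock = NPerBlock = 128, MRepeat = NRepeat = 2,
// MPerXDL = NPerXDL = 32 gives MWave = NWave = 2; with both
// CShuffle*RepeatPerShuffle values set to 1, the LDS staging tile is
// 1 x 64 x 1 x 64 = 4096 FloatC elements, filled and flushed
// MRepeat * NRepeat = 4 times per workgroup.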
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t k_batch_id = block_work_idx[I0];
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto a_b_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
Number<MPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto b_b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
Number<NPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
max_lds_align);
}
}();
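// Note on the ABlockLdsExtraM / BBlockLdsExtraN variants above: padding the
// K0 stride from MPerBlock * K1 (resp. NPerBlock * K1) to
// (MPerBlock + 1) * K1 skews successive K0 slices by one K1-vector in LDS.
// This is the usual bank-conflict padding trick: it spreads column accesses
// across LDS banks at the cost of one extra row of K1 elements per slice.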
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_b_k0_m_k1_block_desc,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
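// Pipeline note (illustrative): each main-loop iteration above overlaps the
// global reads of tile i+1 with the blockwise GEMM on tile i, fenced by two
// block_sync_lds() calls; with a single LDS buffer, the RunWrite of the next
// tile must wait until every wave has finished reading the current one.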
// output: register to global memory
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared_block),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
static_assert(N1 == NWave, "");
static_assert(M2 * M3 * M4 == MPerXDL, "");
static_assert(N2 == NPerXDL, "");
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
M1,
M2,
M3,
M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
make_freeze_transform(I0), // freeze nblock
make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
N1,
N2))), // N1 = NWave, N2 = NPerXDL
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// LDS to global
auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
BlockSize, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
constexpr auto nxdlperwave_forward_step =
make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
constexpr auto nxdlperwave_backward_step =
make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
constexpr auto mxdlperwave = mxdlperwave_iter;
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
constexpr bool nxdlperwave_forward_sweep =
(mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
constexpr index_t nxdlperwave_value =
nxdlperwave_forward_sweep
? nxdlperwave_iter
: (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
// make sure it's safe to do ds_write
block_sync_lds();
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);
// make sure it's safe to do ds_read
block_sync_lds();
// LDS to global
c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
(nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_forward_step);
}
else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_backward_step);
}
});
// move on mxdlperwave dimension
if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
}
});
}
}
}; // struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} // namespace ck
#endif
 # device_gemm_instance
 set(DEVICE_GEMM_INSTANCE_SOURCE
-    device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
-    device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
-    device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
-    device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
-    device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
-    device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
-    device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
-    device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
-    device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
+    # device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
+    # device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
+    # device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
+    # device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
+    # device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
+    # device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
     device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
     device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
     device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
     device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
-    device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
-    device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
-    device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
-    device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
+    # device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
+    # device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
+    # device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
+    # device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
 )

 add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
...
@@ -24,40 +24,40 @@ include_directories(BEFORE
 set(PROFILER_SOURCE
     src/profiler.cpp
     src/profile_gemm.cpp
-    src/profile_gemm_bias_2d.cpp
-    src/profile_gemm_bias_relu.cpp
-    src/profile_gemm_bias_relu_add.cpp
-    src/profile_gemm_reduce.cpp
-    src/profile_batched_gemm.cpp
-    src/profile_conv_fwd_bias_relu.cpp
-    src/profile_conv_fwd_bias_relu_add.cpp
-    src/profile_conv_fwd_bias_relu_atomic_add.cpp
-    src/profile_convnd_fwd.cpp
-    src/profile_convnd_bwd_data.cpp
-    src/profile_reduce.cpp
-    src/profile_grouped_gemm.cpp
-    src/profile_conv_bwd_weight.cpp
-    src/profile_batched_gemm_reduce.cpp
+    # src/profile_gemm_bias_2d.cpp
+    # src/profile_gemm_bias_relu.cpp
+    # src/profile_gemm_bias_relu_add.cpp
+    # src/profile_gemm_reduce.cpp
+    # src/profile_batched_gemm.cpp
+    # src/profile_conv_fwd_bias_relu.cpp
+    # src/profile_conv_fwd_bias_relu_add.cpp
+    # src/profile_conv_fwd_bias_relu_atomic_add.cpp
+    # src/profile_convnd_fwd.cpp
+    # src/profile_convnd_bwd_data.cpp
+    # src/profile_reduce.cpp
+    # src/profile_grouped_gemm.cpp
+    # src/profile_conv_bwd_weight.cpp
+    # src/profile_batched_gemm_reduce.cpp
 )

 add_executable(ckProfiler ${PROFILER_SOURCE})
 target_link_libraries(ckProfiler PRIVATE host_tensor)
-target_link_libraries(ckProfiler PRIVATE conv_fwd_util)
-target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
+# target_link_libraries(ckProfiler PRIVATE conv_fwd_util)
+# target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
 target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
-target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
-target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
-target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
+# target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
+# target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
+# target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
+# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
+# target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
+# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
+# target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
+# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
+# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
+# target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
+# target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
+# target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
+# target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
+# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
@@ -23,56 +23,56 @@ using DeviceGemmNoOpPtr =
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough>;

-void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-
-void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+//
+// void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
+//     std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

 void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
 void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
 void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
 void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+// void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

 } // namespace device_gemm_instance
 } // namespace device
@@ -171,11 +171,11 @@ void profile_gemm_impl(int do_verification,
         }
         else
         {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
-
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
+            //
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
         }
     }
     else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
@@ -189,11 +189,11 @@ void profile_gemm_impl(int do_verification,
         }
         else
        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
-
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
+            //
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
         }
     }
     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
@@ -207,11 +207,11 @@ void profile_gemm_impl(int do_verification,
         }
         else
         {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
-
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
+            //
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
         }
     }
     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
@@ -225,158 +225,158 @@ void profile_gemm_impl(int do_verification,
         }
         else
         {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
-
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
+            //
+            // ck::tensor_operation::device::device_gemm_instance::
+            //     add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
         }
     }
 }
-    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
-                      is_same<CDataType, half_t>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            if(KBatch > 1)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
-            }
-            else
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
-
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
-            }
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            if(KBatch > 1)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-            }
-            else
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-            }
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            if(KBatch > 1)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-            }
-            else
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-            }
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            if(KBatch > 1)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
-            }
-            else
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
-
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
-            }
-        }
-    }
+    // else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
+    //                   is_same<CDataType, half_t>::value)
+    // {
+    //     if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+    //                  is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+    //                  is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         if(KBatch > 1)
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
+    //         }
+    //         else
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
+    //
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
+    //         }
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         if(KBatch > 1)
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+    //         }
+    //         else
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+    //
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+    //
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+    //         }
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         if(KBatch > 1)
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
+    //         }
+    //         else
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
+    //
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
+    //         }
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         if(KBatch > 1)
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
+    //         }
+    //         else
+    //         {
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
+    //
+    //             ck::tensor_operation::device::device_gemm_instance::
+    //                 add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
+    //         }
+    //     }
+    // }
-    else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
-                      is_same<BDataType, ck::bhalf_t>::value &&
-                      is_same<CDataType, ck::bhalf_t>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs);
-        }
-    }
+    // else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
+    //                   is_same<BDataType, ck::bhalf_t>::value &&
+    //                   is_same<CDataType, ck::bhalf_t>::value)
+    // {
+    //     if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+    //                  is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+    //                  is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs);
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs);
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs);
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs);
+    //     }
+    // }
-    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
-                      is_same<CDataType, int8_t>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs);
-        }
-    }
+    // else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
+    //                   is_same<CDataType, int8_t>::value)
+    // {
+    //     if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+    //                  is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+    //                  is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs);
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs);
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs);
+    //     }
+    //     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+    //                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+    //     {
+    //         ck::tensor_operation::device::device_gemm_instance::
+    //             add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs);
+    //     }
+    // }
    if(gemm_ptrs.size() <= 0)
    {
...
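
For readers skimming the diff: the instance-selection logic above is a compile-time dispatch on layout tags (via if constexpr over is_same) combined with a runtime split-K switch on KBatch. Below is a minimal, self-contained sketch of that pattern; BaseOp, XdlGemm, XdlSplitKGemm, the add_*_instances helpers, and collect_instances are hypothetical stand-ins, not the actual CK API.

// Minimal sketch of the dispatch pattern; all names here are stand-ins.
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>

struct RowMajor{};
struct ColumnMajor{};

// Stand-in for the abstract device-GEMM interface behind the op pointers.
struct BaseOp
{
    virtual ~BaseOp() = default;
};
struct XdlGemm : BaseOp{};       // stand-in for a regular xdlops GEMM instance
struct XdlSplitKGemm : BaseOp{}; // stand-in for a split-K xdlops GEMM instance

using OpPtrs = std::vector<std::unique_ptr<BaseOp>>;

// Stand-ins for the add_device_gemm_*_instances() free functions; the real ones
// append one op per tuning-parameter variant for a fixed layout combination.
void add_mk_kn_mn_instances(OpPtrs& ptrs) { ptrs.push_back(std::make_unique<XdlGemm>()); }
void add_splitk_mk_kn_mn_instances(OpPtrs& ptrs)
{
    ptrs.push_back(std::make_unique<XdlSplitKGemm>());
}

template <typename ALayout, typename BLayout>
OpPtrs collect_instances(int k_batch)
{
    OpPtrs ptrs;

    // Only the branch matching the layout combination is compiled in.
    if constexpr(std::is_same<ALayout, RowMajor>::value &&
                 std::is_same<BLayout, RowMajor>::value)
    {
        if(k_batch > 1)
            add_splitk_mk_kn_mn_instances(ptrs); // split-K kernels when K is batched
        else
            add_mk_kn_mn_instances(ptrs);
    }
    // ... the remaining layout combinations would follow the same shape ...

    if(ptrs.size() <= 0) // same guard as in the profiler above
        throw std::runtime_error("no device GEMM instance found for this layout");

    return ptrs;
}

int main()
{
    auto ptrs = collect_instances<RowMajor, RowMajor>(/*k_batch=*/1);
    return ptrs.empty(); // 0 on success
}

Because the layout test is if constexpr, only the instance factories for the requested layout combination are instantiated into the profiler binary; the KBatch test stays a runtime branch because it depends on a command-line argument.
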
@@ -7,19 +7,19 @@
#include "profile_convnd_fwd.hpp" #include "profile_convnd_fwd.hpp"
int profile_gemm(int, char*[]); int profile_gemm(int, char*[]);
int profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
int profile_convnd_bwd_data(int, char*[], int);
int profile_reduce(int, char*[]);
int profile_conv_bwd_weight(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
// int profile_gemm_bias_2d(int, char*[]);
// int profile_gemm_bias_relu(int, char*[]);
// int profile_gemm_bias_relu_add(int, char*[]);
// int profile_gemm_reduce(int, char*[]);
// int profile_batched_gemm(int, char*[]);
// int profile_grouped_gemm(int, char*[]);
// int profile_conv_fwd_bias_relu(int, char*[]);
// int profile_conv_fwd_bias_relu_add(int, char*[]);
// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
// int profile_convnd_bwd_data(int, char*[], int);
// int profile_reduce(int, char*[]);
// int profile_conv_bwd_weight(int, char*[]);
// int profile_batched_gemm_reduce(int, char*[]);

int main(int argc, char* argv[])
{
@@ -27,70 +27,70 @@ int main(int argc, char* argv[])
    {
        return profile_gemm(argc, argv);
    }
else if(strcmp(argv[1], "gemm_bias_2d") == 0) // else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{ // {
return profile_gemm_bias_2d(argc, argv); // return profile_gemm_bias_2d(argc, argv);
} // }
else if(strcmp(argv[1], "gemm_bias_relu") == 0) // else if(strcmp(argv[1], "gemm_bias_relu") == 0)
{ // {
return profile_gemm_bias_relu(argc, argv); // return profile_gemm_bias_relu(argc, argv);
} // }
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0) // else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
{ // {
return profile_gemm_bias_relu_add(argc, argv); // return profile_gemm_bias_relu_add(argc, argv);
} // }
else if(strcmp(argv[1], "gemm_reduce") == 0) // else if(strcmp(argv[1], "gemm_reduce") == 0)
{ // {
return profile_gemm_reduce(argc, argv); // return profile_gemm_reduce(argc, argv);
} // }
else if(strcmp(argv[1], "batched_gemm") == 0) // else if(strcmp(argv[1], "batched_gemm") == 0)
{ // {
return profile_batched_gemm(argc, argv); // return profile_batched_gemm(argc, argv);
} // }
else if(strcmp(argv[1], "batched_gemm_reduce") == 0) // else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{ // {
return profile_batched_gemm_reduce(argc, argv); // return profile_batched_gemm_reduce(argc, argv);
} // }
else if(strcmp(argv[1], "grouped_gemm") == 0) // else if(strcmp(argv[1], "grouped_gemm") == 0)
{ // {
profile_grouped_gemm(argc, argv); // profile_grouped_gemm(argc, argv);
} // }
else if(strcmp(argv[1], "conv_fwd") == 0) // else if(strcmp(argv[1], "conv_fwd") == 0)
{ // {
return ck::profiler::profile_convnd_fwd(argc, argv); // return ck::profiler::profile_convnd_fwd(argc, argv);
} // }
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) // else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{ // {
return profile_conv_fwd_bias_relu(argc, argv); // return profile_conv_fwd_bias_relu(argc, argv);
} // }
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) // else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
{ // {
return profile_conv_fwd_bias_relu_add(argc, argv); // return profile_conv_fwd_bias_relu_add(argc, argv);
} // }
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0) // else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
{ // {
return profile_conv_fwd_bias_relu_atomic_add(argc, argv); // return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
} // }
else if(strcmp(argv[1], "conv1d_bwd_data") == 0) // else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
{ // {
return profile_convnd_bwd_data(argc, argv, 1); // return profile_convnd_bwd_data(argc, argv, 1);
} // }
else if(strcmp(argv[1], "conv2d_bwd_data") == 0) // else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
{ // {
return profile_convnd_bwd_data(argc, argv, 2); // return profile_convnd_bwd_data(argc, argv, 2);
} // }
else if(strcmp(argv[1], "conv3d_bwd_data") == 0) // else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
{ // {
return profile_convnd_bwd_data(argc, argv, 3); // return profile_convnd_bwd_data(argc, argv, 3);
} // }
else if(strcmp(argv[1], "reduce") == 0) // else if(strcmp(argv[1], "reduce") == 0)
{ // {
return profile_reduce(argc, argv); // return profile_reduce(argc, argv);
} // }
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) // else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
{ // {
return profile_conv_bwd_weight(argc, argv); // return profile_conv_bwd_weight(argc, argv);
} // }
    else
    {
        // clang-format off
...
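
The second file is the profiler driver: a plain strcmp chain over argv[1], where each branch forwards to one profile_* entry point. This commit comments out every branch except "gemm", presumably while the gridwise-GEMM refactor described in the commit message lands. A minimal runnable sketch of that dispatch shape, with a stub standing in for the real profile_gemm:

// Minimal sketch of the profiler's subcommand dispatch; profile_gemm here is a
// stub, not the real entry point.
#include <cstring>
#include <iostream>

int profile_gemm(int /*argc*/, char* /*argv*/[])
{
    std::cout << "profiling gemm\n";
    return 0;
}

int main(int argc, char* argv[])
{
    if(argc < 2)
    {
        std::cout << "usage: profiler <operation> [operation-specific args]\n";
        return 1;
    }

    if(strcmp(argv[1], "gemm") == 0)
    {
        return profile_gemm(argc, argv);
    }
    // In this commit the remaining branches (batched_gemm, conv_fwd, reduce,
    // conv2d_bwd_weight, ...) are commented out until they are re-enabled.

    std::cout << "unrecognized operation: " << argv[1] << "\n";
    return 1;
}

The strcmp chain keeps each profiler self-contained at the cost of one branch per operation, which is also what makes it easy to disable operations wholesale during a refactor, as this commit does.
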