Commit 326d331c authored by myamlak's avatar myamlak
Browse files

Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

parents 5fd5daab ac543313
......@@ -11,7 +11,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
namespace ck {
namespace tensor_operation {
......@@ -81,6 +81,22 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;
static constexpr auto N1Number = K1Number;
// Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = 128;
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
// M1 & M0
static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
static constexpr auto ABlockLdsM1Padding = 4;
// N1 & N0
static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
static constexpr auto BBlockLdsN1Padding = 4;
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
......@@ -125,27 +141,51 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const index_t N0 = N / N1Number;
const index_t GemmK0Total = N0 * Ho * Wo;
const index_t GemmK0S =
math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock;
const index_t GemmK0Pad = GemmKBatch * GemmK0S;
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K));
const auto out_n0_ho_wo_k_n1_grid_desc =
transform_tensor_descriptor(out_n_ho_wo_k_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
make_pass_through_transform(Ho * Wo),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmk0total_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc,
make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)),
make_pass_through_transform(K),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmk0total_gemmm_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
make_pass_through_transform(GemmM),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
out_gemmk0pad_gemmm_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmM),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
......@@ -167,26 +207,50 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmn_grid_desc =
const auto in_n0_y_ho_x_wo_c_n1_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
make_pass_through_transform(Y),
make_pass_through_transform(Ho),
make_pass_through_transform(X),
make_pass_through_transform(Wo),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 6>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_n0_y_ho_x_wo_c_n1_grid_desc,
make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C)),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmk0total_gemmn_gemmk1_grid_desc,
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
make_pass_through_transform(GemmN),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
in_gemmk0pad_gemmn_gemmk1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
make_pass_through_transform(GemmN),
make_pass_through_transform(N1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
......@@ -205,7 +269,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
......@@ -233,6 +297,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
......@@ -241,12 +308,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
......@@ -274,6 +346,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
......@@ -282,10 +357,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
......@@ -353,17 +433,16 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
......@@ -422,14 +501,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting");
}
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
......@@ -444,6 +523,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
......@@ -465,7 +545,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -482,7 +562,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -502,7 +582,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -519,7 +599,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -562,6 +642,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return false;
}
// unmerge N to N0 and N1, where N1 equals to K1
if(!(arg.Conv_N_ % K1 == 0))
{
return false;
}
// vector store C matrix into global memory
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
{
......@@ -572,8 +658,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -486,13 +486,16 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
......@@ -572,15 +575,14 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
......@@ -703,8 +705,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
return false;
}
......
......@@ -540,7 +540,8 @@ struct
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
block_2_ctile_map_{},
block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
N01_{N01},
in_element_op_{in_element_op},
......@@ -575,8 +576,10 @@ struct
c0_grid_desc_m_n_ = descs[I3];
c1_grid_desc_m_n_ = descs[I4];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
......@@ -592,9 +595,6 @@ struct
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -689,14 +689,14 @@ struct
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -852,8 +852,7 @@ struct
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -548,9 +548,13 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
c0_grid_desc_m_n_ = descs[I3];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
......@@ -561,9 +565,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -649,14 +650,14 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -802,8 +803,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -520,18 +520,20 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
c_grid_desc_m_n_ = descs[I2];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -631,14 +633,14 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -774,8 +776,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -408,15 +408,16 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -469,14 +470,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -606,8 +607,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -259,50 +259,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
struct Block2CTileMapMaker
{
Block2CTileMapMaker(index_t num_batches) : num_batches_(num_batches) {}
__host__ __device__ constexpr auto
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(num_batches_),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(num_batches_, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
private:
index_t num_batches_;
};
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
InDataType,
......@@ -345,8 +301,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(Block2CTileMapMaker{1}.MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Argument
struct Argument : public BaseArgument
......@@ -398,18 +353,20 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
b_batch_stride_ = 0;
c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize();
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = Block2CTileMapMaker{num_subbatches_}.MakeBlock2CTileMap(
c_grid_desc_m_n_, M01, N01);
}
}
......@@ -457,16 +414,15 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
// todo: grid_size times arg.num_subbatches_
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.num_subbatches_;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) *
arg.num_subbatches_;
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
......@@ -565,8 +521,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -1073,13 +1073,15 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
......@@ -1129,13 +1131,16 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
......@@ -1194,14 +1199,17 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
......@@ -1286,15 +1294,14 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
......@@ -1418,8 +1425,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
return false;
}
......
......@@ -705,15 +705,16 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -766,14 +767,14 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -916,8 +917,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
......@@ -6,17 +6,19 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D1ElementwiseOperation>
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_d0,
void* p_d1,
DPtrsGlobal p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -26,20 +28,25 @@ struct DeviceGemmReduce : public BaseOperator
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D1ElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation>>;
DxsInElementwiseOperation,
DxsOutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -26,13 +26,14 @@ template <typename ALayout,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DDataType,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -67,10 +68,12 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation>
DxsInElementwiseOperation,
DxsOutElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
......@@ -380,15 +383,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType,
CDataType,
ReduceAccDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
......@@ -435,8 +438,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DDataType* p_d0_grid,
DDataType* p_d1_grid,
DPtrsGlobal p_ds_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -446,26 +448,29 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op)
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_d0_grid_{p_d0_grid},
p_d1_grid_{p_d1_grid},
p_ds_grid_{p_ds_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d1_element_op_{d1_element_op}
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -473,8 +478,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
......@@ -482,8 +485,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
DDataType* p_d0_grid_;
DDataType* p_d1_grid_;
DPtrsGlobal p_ds_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
......@@ -495,7 +497,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D1ElementwiseOperation d1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsOutElementwiseOperation dxs_out_element_op_;
};
// Invoker
......@@ -525,13 +528,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -543,11 +549,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -564,12 +571,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.p_ds_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -582,11 +589,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -603,12 +611,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.p_ds_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -635,8 +643,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
......@@ -648,8 +658,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
DDataType* p_d0,
DDataType* p_d1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -659,13 +668,13 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op)
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_d0,
p_d1,
p_dxs,
MRaw,
NRaw,
KRaw,
......@@ -675,7 +684,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d1_element_op};
dxs_in_element_op,
dxs_out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -684,8 +694,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_d0,
void* p_d1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -695,14 +704,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<DDataType*>(p_d0),
static_cast<DDataType*>(p_d1),
p_dxs,
MRaw,
NRaw,
KRaw,
......@@ -712,7 +721,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d1_element_op);
dxs_in_element_op,
dxs_out_element_op);
}
// polymorphic
......
......@@ -257,14 +257,16 @@ struct DeviceGemmXdl
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -310,14 +312,14 @@ struct DeviceGemmXdl
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -409,8 +411,7 @@ struct DeviceGemmXdl
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -218,8 +218,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
c_grid_desc_m_n_ =
DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
......@@ -230,9 +235,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -285,14 +287,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -400,8 +402,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -227,8 +227,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
c_grid_desc_m_n_ = descs[I2];
c0_grid_desc_m_n_ = descs[I3];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
......@@ -239,9 +244,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -294,14 +296,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -409,8 +411,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -256,8 +256,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
c0_grid_desc_m_n_ = descs[I3];
c1_grid_desc_m_n_ = descs[I4];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
......@@ -273,9 +278,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -336,14 +338,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
......@@ -461,8 +463,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -404,19 +404,19 @@ struct DeviceGemm_Xdl_CShuffle
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
......@@ -459,13 +459,16 @@ struct DeviceGemm_Xdl_CShuffle
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -555,8 +558,10 @@ struct DeviceGemm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -332,17 +332,16 @@ struct DeviceGemmXdlSplitK
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
......@@ -395,14 +394,14 @@ struct DeviceGemmXdlSplitK
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
......@@ -532,8 +531,7 @@ struct DeviceGemmXdlSplitK
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -292,8 +292,7 @@ struct DeviceGemmXdlSplitKCShuffle
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
// Argument
struct Argument : public BaseArgument
......@@ -338,17 +337,16 @@ struct DeviceGemmXdlSplitKCShuffle
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
......@@ -401,14 +399,14 @@ struct DeviceGemmXdlSplitKCShuffle
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
......@@ -541,8 +539,7 @@ struct DeviceGemmXdlSplitKCShuffle
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic
......
......@@ -307,6 +307,11 @@ struct DeviceGroupedGemmXdl
struct GroupedGemmBlock2CTileMap
{
using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
static_assert(
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)),
typename GridwiseGemm::DefaultBlock2CTileMap>::value,
"Wrong! Should be the same type name");
GroupedGemmBlock2CTileMap()
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
......@@ -329,6 +334,18 @@ struct DeviceGroupedGemmXdl
make_multi_index(idx_top[I0] - BlockStart_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_2_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
private:
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
ck::index_t BlockStart_;
......@@ -400,22 +417,27 @@ struct DeviceGroupedGemmXdl
const auto c_grid_desc_m_n_ =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
const index_t grid_size_grp =
typename GroupedGemmBlock2CTileMap::UnderlyingBlock2CTileMap(
c_grid_desc_m_n_, M01, N01)
.CalculateGridSize(c_grid_desc_m_n_);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
const auto grouped_gemm_block_2_ctile_map_ =
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
grouped_gemm_block_2_ctile_map_))
{
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
const auto grouped_gemm_block_2_ctile_map_ =
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
gemm_desc_kernel_arg_.push_back(
GemmDescKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
......@@ -475,11 +497,11 @@ struct DeviceGroupedGemmXdl
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
if(!GridwiseGemm::CheckValidity(
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment