Commit 0530fd66 authored by Chao Liu

update gemm multi-d

parent 05c484e2
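
For context: before this change, DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle built the K-split grid descriptors A[AK0, M, AK1] and B[BK0, N, BK1] itself, repeating the same unmerge transform in every convolution specialization. After this change the device op only produces plain A[M, K], B[N, K] and E[M, N] descriptors, and the gridwise GEMM exposes MakeDefaultAGridDescriptor_AK0_M_AK1 and MakeDefaultBGridDescriptor_BK0_N_BK1 to perform the split in one place. The snippet below is a minimal, standalone illustration of the index arithmetic behind that unmerge; it is not CK code, and the sizes are made up for the example.

#include <cassert>
#include <cstdio>

int main()
{
    // Made-up sizes for illustration only.
    const int M = 4, K = 8, AK1 = 2;
    assert(K % AK1 == 0);
    const int AK0 = K / AK1;

    // Offset of element (m, k) in a row-major M x K buffer.
    auto offset_m_k = [&](int m, int k) { return m * K + k; };

    // The unmerged A[AK0, M, AK1] view of the same buffer uses strides
    // (AK1, K, 1): ak0 steps AK1 elements along K, m steps a whole K-row,
    // ak1 steps by one element.
    auto offset_ak0_m_ak1 = [&](int ak0, int m, int ak1) {
        return ak0 * AK1 + m * K + ak1;
    };

    // Both views address the same element when k = ak0 * AK1 + ak1.
    for(int ak0 = 0; ak0 < AK0; ++ak0)
        for(int m = 0; m < M; ++m)
            for(int ak1 = 0; ak1 < AK1; ++ak1)
                assert(offset_ak0_m_ak1(ak0, m, ak1) == offset_m_k(m, ak0 * AK1 + ak1));

    std::printf("A[M=%d, K=%d] viewed as A[AK0=%d, M=%d, AK1=%d]\n", M, K, AK0, M, AK1);
    return 0;
}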
@@ -11,7 +11,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-// GEMM:
+// Convolution Forward:
 // input : input image A[N, Hi, Wi, C],
 // input : weight B[K, Y, X, C],
 // input : D0[N, Ho, Wo, K], D1[N, Ho, Wo, K], ...
...
@@ -77,18 +77,18 @@ __global__ void
         e_grid_desc_mblock_mperblock_nblock_nperblock,
         block_2_etile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = block_2_etile_map;
 #endif
 }
 } // namespace
@@ -197,18 +197,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         const auto wei_gemmn_gemmk_grid_desc =
             matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);

-        const auto GemmN = wei_gemmn_gemmk_grid_desc.GetLength(I0);
-        const auto GemmK = wei_gemmn_gemmk_grid_desc.GetLength(I1);
-
-        const index_t GemmK0 = GemmK / GemmK1Number;
-
-        // wei_gemmk0_gemmn_gemmk1_grid_desc
-        return transform_tensor_descriptor(
-            wei_gemmn_gemmk_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                       make_pass_through_transform(GemmN)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return wei_gemmn_gemmk_grid_desc;
     }

     static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN)
@@ -250,18 +239,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            // in_gemmk0_gemmm_gemmk1_grid_desc
-            return transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            return in_gemmm_gemmk_grid_desc;
         }
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -286,19 +264,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
         else
         {
@@ -337,19 +303,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_pass_through_transform(GemmM),
-                           make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
     }
@@ -384,18 +338,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            // in_gemmk0_gemmm_gemmk1_grid_desc
-            return transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            return in_gemmm_gemmk_grid_desc;
         }
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -422,19 +365,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
         else
         {
@@ -482,19 +413,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
            const auto in_gemmm_gemmk_grid_desc =
                matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_pass_through_transform(GemmM),
-                           make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
     }
@@ -532,18 +451,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
            const auto in_gemmm_gemmk_grid_desc =
                matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            // in_gemmk0_gemmm_gemmk1_grid_desc
-            return transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            return in_gemmm_gemmk_grid_desc;
         }
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -573,19 +481,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
            const auto in_gemmm_gemmk_grid_desc =
                matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
         else
         {
@@ -646,19 +542,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
            const auto in_gemmm_gemmk_grid_desc =
                matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_pass_through_transform(GemmM),
-                           make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
     }
@@ -696,11 +580,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         const index_t GemmNRaw = K;
         const index_t GemmKRaw = GetGemmKRaw(C, filter_spatial_lengths);

-        // TODO: remove
-        assert(GemmKRaw % GemmK1Number == 0);
-
         // A:
-        const auto in_gemmk0_gemmm_gemmk1_grid_desc =
+        const auto in_gemmm_gemmk_grid_desc =
             GetInputTensorDescriptor<NDimSpatial>(N,
                                                   C,
                                                   GemmMRaw,
@@ -714,15 +595,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                                                   input_right_pads);

         // B:
-        const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
-            GetWeightTensorDescriptor(GemmNRaw, GemmKRaw);
+        const auto wei_gemmn_gemmk_grid_desc = GetWeightTensorDescriptor(GemmNRaw, GemmKRaw);

         // E:
         const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmNRaw);

-        return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
-                          wei_gemmk0_gemmn_gemmk1_grid_desc,
-                          out_gemmm_gemmn_grid_desc);
+        return make_tuple(
+            in_gemmm_gemmk_grid_desc, wei_gemmn_gemmk_grid_desc, out_gemmm_gemmn_grid_desc);
     }

     template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -748,9 +627,9 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     using ABEGridDescs = decltype(GetABEGridDesc<NDimSpatial>());

-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
+    using AGridDesc_M_K = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
+    using BGridDesc_N_K = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
     using EGridDesc_M_N = remove_cvref_t<decltype(ABEGridDescs{}[I2])>;

     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
@@ -763,8 +642,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         BElementwiseOperation,
         CDEElementwiseOperation,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_AK0_M_AK1,
-        BGridDesc_BK0_N_BK1,
+        AGridDesc_M_K,
+        BGridDesc_N_K,
         EGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
@@ -799,11 +678,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-#if 0
-    using Block2ETileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, EGridDesc_M_N>;
-#else
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
+        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
+        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+
     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
-#endif

     // Argument
     struct Argument : public BaseArgument
@@ -856,9 +736,14 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                                                 input_left_pads,
                                                 input_right_pads);

-            a_grid_desc_ak0_m_ak1_ = descs[I0];
-            b_grid_desc_bk0_n_bk1_ = descs[I1];
+            const auto a_grid_desc_m_k = descs[I0];
+            const auto b_grid_desc_n_k = descs[I1];
             e_grid_desc_m_n_ = descs[I2];

+            a_grid_desc_ak0_m_ak1_ =
+                GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+            b_grid_desc_bk0_n_bk1_ =
+                GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
             block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
@@ -917,7 +802,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
-#if 0
+#if 1
            {
                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
@@ -1010,6 +895,20 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+#if 1
+        {
+            std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0)
+                      << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
+                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
+
+            std::cout << "arg.b_grid_desc_bk0_n_bk1_{" << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0)
+                      << ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
+                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
+
+            std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
+                      << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+        }
+#endif
+
         if(ck::get_device_name() == "gfx908")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
...
@@ -18,8 +18,8 @@
 namespace ck {

 // GEMM:
-// input : A[AK0, M, AK1]
-// input : B[AK0, N, AK1]
+// input : A[AK0PerBlock, M, AK1]
+// input : B[AK0PerBlock, N, AK1]
 // input : D0[M, N], D1[M, N], ...
 // output : E[M, N]
 // C = a_op(A) * b_op(B)
@@ -35,8 +35,8 @@ template <typename FloatAB,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
           InMemoryDataOperationEnum EGlobalMemoryDataOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
+          typename AGridDesc_M_K,
+          typename BGridDesc_N_K,
           typename EGridDesc_M_N,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     static constexpr auto I7 = Number<7>{};

     // K1 should be Number<...>
-    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
-    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
-    static constexpr auto AK1 = Number<AK1Value>{};
-    static constexpr auto BK1 = Number<BK1Value>{};
+    static constexpr auto AK1 = Number<AK1Value>{};
+    static constexpr auto BK1 = Number<BK1Value>{};
+    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
+    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         // A matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
+            make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
             make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
     }
@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         // B matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<NPerBlock>{}, BK1),
+            make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
             make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
     }
@@ -164,8 +164,65 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
             c_block_size * sizeof(FloatCShuffle));
     }

+    __host__ __device__ static constexpr auto
+    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
+    {
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        const auto AK0 = K / AK1;
+
+        return transform_tensor_descriptor(a_grid_desc_m_k,
+                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                                                      make_pass_through_transform(M)),
+                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
+    {
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        const auto BK0 = K / BK1;
+
+        return transform_tensor_descriptor(b_grid_desc_n_k,
+                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                                                      make_pass_through_transform(N)),
+                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
+    {
+        const auto M = e_grid_desc_m_n.GetLength(I0);
+        const auto N = e_grid_desc_m_n.GetLength(I1);
+
+        const auto MBlock = M / MPerBlock;
+        const auto NBlock = N / NPerBlock;
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
+            e_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+
+        return e_grid_desc_mblock_mperblock_nblock_nperblock;
+    }
+
+    // return block_id to E matrix tile idx (m0, n0) mapping
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
+            e_grid_desc_m_n);
+    }
+
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-    template <typename Block2ETileMap>
+    template <typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename Block2ETileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
@@ -210,32 +267,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

-    __host__ __device__ static constexpr auto
-    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
-    {
-        const auto M = e_grid_desc_m_n.GetLength(I0);
-        const auto N = e_grid_desc_m_n.GetLength(I1);
-
-        const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / NPerBlock;
-
-        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
-            e_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
-                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
-        return e_grid_desc_mblock_mperblock_nblock_nperblock;
-    }
-
-    // return block_id to E matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
-    {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
-            e_grid_desc_m_n);
-    }
+    using DefaultAGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+    using DefaultBGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;

     using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
@@ -245,7 +280,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     using DsGridPointer = decltype(MakeDsGridPointer());

-    template <bool HasMainKBlockLoop, typename Block2ETileMap>
+    template <bool HasMainKBlockLoop,
+              typename AGridDesc_AK0_M_AK1,
+              typename BGridDesc_BK0_N_BK1,
+              typename Block2ETileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
@@ -316,7 +354,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
-            Sequence<AK0, MPerBlock, AK1>,
+            Sequence<AK0PerBlock, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
@@ -347,7 +385,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
-            Sequence<BK0, NPerBlock, BK1>,
+            Sequence<BK0PerBlock, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
...
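
The E-side helpers that moved within the gridwise GEMM perform the analogous split on the output: MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock views E[M, N] as E[MBlock, MPerBlock, NBlock, NPerBlock], and MakeDefaultBlock2ETileMap maps a block id to one MPerBlock x NPerBlock tile of E. The standalone sketch below illustrates only that tiling arithmetic, under the assumption that MPerBlock divides M and NPerBlock divides N; the sizes are made up and this is not CK code.

#include <cassert>
#include <cstdio>

int main()
{
    // Made-up sizes for illustration only.
    const int M = 8, N = 6, MPerBlock = 4, NPerBlock = 3;
    assert(M % MPerBlock == 0 && N % NPerBlock == 0);
    const int MBlock = M / MPerBlock;
    const int NBlock = N / NPerBlock;

    // Offset of element (m, n) in a row-major M x N buffer.
    auto offset_m_n = [&](int m, int n) { return m * N + n; };

    // The blocked E[MBlock, MPerBlock, NBlock, NPerBlock] view of the same
    // buffer uses strides (MPerBlock * N, N, NPerBlock, 1).
    auto offset_blocked = [&](int mb, int mp, int nb, int np) {
        return mb * MPerBlock * N + mp * N + nb * NPerBlock + np;
    };

    // Both views address the same element when
    // m = mb * MPerBlock + mp and n = nb * NPerBlock + np.
    for(int mb = 0; mb < MBlock; ++mb)
        for(int mp = 0; mp < MPerBlock; ++mp)
            for(int nb = 0; nb < NBlock; ++nb)
                for(int np = 0; np < NPerBlock; ++np)
                    assert(offset_blocked(mb, mp, nb, np) ==
                           offset_m_n(mb * MPerBlock + mp, nb * NPerBlock + np));

    std::printf("E[%d, %d] tiled into %d x %d tiles of %d x %d\n",
                M, N, MBlock, NBlock, MPerBlock, NPerBlock);
    return 0;
}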