Commit 060c4f3a authored by aska-0096

Skip B Lds Gemm + MulD
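
In short: this change lets the B operand of the WMMA CShuffle GEMM (and its MultipleD variant) bypass the LDS staging buffer. When the BEnableLds template parameter is false, B tiles are read straight from global memory into VGPRs through a 6-D per-thread descriptor instead of being staged block-wide through LDS. A minimal standalone sketch of that compile-time dispatch follows; the function name and messages are illustrative placeholders, not library APIs.

#include <cstdio>

// Sketch only: BEnableLds selects between the two B-operand paths at compile time.
template <bool BEnableLds>
void run_b_pipeline()
{
    if constexpr(BEnableLds)
    {
        // Original path: the B tile is staged block-wide through LDS.
        std::puts("B: global -> LDS -> WMMA");
    }
    else
    {
        // New path: each thread copies its B fragments directly into registers.
        std::puts("B: global -> VGPR -> WMMA");
    }
}

int main()
{
    run_b_pipeline<true>();
    run_b_pipeline<false>();
}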

parent 04c6a978
......@@ -87,8 +87,8 @@ using DeviceOpInstance =
8,
16,
16,
1,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -105,7 +105,7 @@ using DeviceOpInstance =
true,
1,
1,
S<1, 128, 1, 2>,
S<1, 16, 1, 16>,
8>;
int main(int argc, char* argv[])
......
......@@ -105,6 +105,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data is read from global memory
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() {
......@@ -115,12 +116,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
const auto a_grid_desc_mraw_kraw =
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
#endif
}();
const auto M = a_grid_desc_m_k.GetLength(I0);
......@@ -155,42 +157,56 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
const auto b_grid_desc_n_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(BEnableLds)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
const index_t K0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_k_n,
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
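
For a concrete picture of the two layouts MakeBGridDescriptor can produce, here is a small standalone sketch of the dimension arithmetic with hypothetical tile parameters (the names mirror the descriptor code above; none of the concrete values come from this commit).

#include <cassert>
#include <cstdio>

// Hypothetical tile parameters, chosen only to make the arithmetic concrete.
constexpr int WmmaK     = 16;
constexpr int K1        = 8;
constexpr int NPerBlock = 128; // = NRepeat * NWaves * NPerWmma
constexpr int NRepeat   = 4;
constexpr int NWaves    = 2;
constexpr int NPerWmma  = 16;

int main()
{
    const int N = 512, K = 256; // already-padded problem sizes
    assert(K % K1 == 0);

    // BEnableLds == true: (N, K) is unmerged into (K0, N, K1).
    const int K0 = K / K1;
    std::printf("LDS path:    (K0, N, K1) = (%d, %d, %d)\n", K0, N, K1);

    // BEnableLds == false: (N, K) is unmerged into the 6-D per-thread view
    // (KWmma, N0 * NRepeat, NWaves, KRow, NPerWmma, K1).
    const int KRow  = WmmaK / K1; // B_KRow in the code above
    const int KWmma = K / WmmaK;  // B_KWmma
    const int N0    = N / NPerBlock;
    std::printf("direct path: (%d, %d, %d, %d, %d, %d)\n",
                KWmma, N0 * NRepeat, NWaves, KRow, NPerWmma, K1);
}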
......@@ -245,7 +261,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
// Gridwise descriptors, mapping to the whole given problem.
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
......@@ -260,7 +276,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType,
// InMemory Data Descriptor
AGridDesc,
BGridDesc_K0_N_K1,
BGridDesc,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
......@@ -329,7 +345,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc{},
b_grid_desc_k0_n_k1_{},
b_grid_desc{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
......@@ -342,7 +358,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
cde_element_op_{cde_element_op}
{
a_grid_desc = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
b_grid_desc = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
......@@ -359,7 +375,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
if(GridwiseOp::CheckValidity(a_grid_desc,
b_grid_desc_k0_n_k1_,
b_grid_desc,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_ctile_map_))
......@@ -382,7 +398,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
// Tensor Descriptors
AGridDesc a_grid_desc;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
BGridDesc b_grid_desc;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -410,24 +426,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc{" << arg.a_grid_desc.GetLength(I0)
<< ", " << arg.a_grid_desc.GetLength(I1) << ", "
<< arg.a_grid_desc.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
}
#endif
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_))
......@@ -439,8 +439,17 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
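// K is recovered from the A grid descriptor: with AEnableLds it is the 3-D
// (K0, M, K1) view, so K = K0 * K1 (lengths I0 and I2); without LDS it is the
// 6-D per-thread view (KWmma at I0, KRow at I3, K1 at I5, mirroring the B grid
// descriptor), so K = KWmma * KRow * K1.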
const auto K = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
}
else
{
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
arg.a_grid_desc.GetLength(I5);
}
}();
float ave_time = 0;
......@@ -453,7 +462,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc>,
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename DeviceOp::BGridDesc>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
......@@ -475,7 +484,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
......@@ -492,7 +501,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
typename GridwiseOp::DsGridPointer,
EDataType,
remove_reference_t<typename DeviceOp::AGridDesc>,
remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename DeviceOp::BGridDesc>,
remove_reference_t<
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
......@@ -514,7 +523,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
......@@ -555,7 +564,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
return GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_);
......
......@@ -344,22 +344,6 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_{" << arg.a_grid_desc_.GetLength(I0)
<< ", " << arg.a_grid_desc_.GetLength(I1) << ", "
<< arg.a_grid_desc_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
......@@ -372,7 +356,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto GetK = [&]() {
const auto K = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
......@@ -382,8 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
};
const auto K = GetK();
}();
float ave_time = 0;
......
......@@ -46,7 +46,7 @@ __global__ void
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -85,7 +85,7 @@ __global__ void
p_e_grid + e_batch_offset,
p_shared,
a_grid_desc,
b_grid_desc_k0_n_k1,
b_grid_desc,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
a_element_op,
......@@ -99,7 +99,7 @@ __global__ void
ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc;
ignore = b_grid_desc_k0_n_k1;
ignore = b_grid_desc;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = a_element_op;
......@@ -116,7 +116,7 @@ template <typename GridwiseOp,
typename DsPointer,
typename EDataType,
typename AGridDesc,
typename BGridDesc_K0_N_K1,
typename BGridDesc,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
......@@ -136,7 +136,7 @@ __global__ void
EDataType* __restrict__ p_e_grid,
const index_t batch_count,
const AGridDesc a_grid_desc,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const BGridDesc b_grid_desc,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -177,7 +177,7 @@ __global__ void
p_e_grid + e_batch_offset,
p_shared,
a_grid_desc,
b_grid_desc_k0_n_k1,
b_grid_desc,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
......@@ -194,7 +194,7 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc;
ignore = b_grid_desc_k0_n_k1;
ignore = b_grid_desc;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
......@@ -208,7 +208,7 @@ template <typename GridwiseOp,
typename DsPointer,
typename EDataType,
typename AGridDesc,
typename BGridDesc_K0_N_K1,
typename BGridDesc,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
......@@ -226,7 +226,7 @@ __global__ void
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AGridDesc a_grid_desc,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const BGridDesc b_grid_desc,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -245,7 +245,7 @@ __global__ void
p_e_grid,
p_shared,
a_grid_desc,
b_grid_desc_k0_n_k1,
b_grid_desc,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
......@@ -258,7 +258,7 @@ __global__ void
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_grid_desc;
ignore = b_grid_desc_k0_n_k1;
ignore = b_grid_desc;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
......@@ -277,7 +277,7 @@ template < // DataType Family
typename EDataType,
// InMemory Data Descriptor
typename AGridDesc,
typename BGridDesc_K0_N_K1,
typename BGridDesc,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
// ElementwiseOp Family
......@@ -385,6 +385,40 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return a_block_desc;
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor()
{
constexpr auto b_block_desc = [&]() {
if constexpr(BEnableLds)
{
// K0->N->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto max_lds_align = K1;
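// Assumption: the extra element per row ((NPerBlock + 1) * K1 stride) below is
// the usual padding trick to reduce LDS bank conflicts when BBlockLdsExtraN is set.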
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
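// The NWave, NRow and NPerWmma dimensions are kept as size-1 placeholders, so
// each thread's register tile holds only KWmmaPerblock * NRepeat * K1 elements.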
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
}
}();
return b_block_desc;
}
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
......@@ -478,44 +512,56 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return a_wave_desc;
}
template <typename BBlockDesc_BK0_N_BK1>
template <typename BBlockDesc_>
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
MakeBWaveDescriptor(const BBlockDesc_&)
{
constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
return transform_tensor_descriptor(
BBlockDesc_BK0_N_BK1{},
BBlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
constexpr auto K0PerBlock = KPerBlock / K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> KWmma_NRepeat_NWave_NPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I5);
// Workaround: freeze the size-1 KRow dimension (dim I3) so it drops out of the wave descriptor
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<NRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
}
}();
return b_block_desc_k0perblock_nperblock_k1;
return b_wave_desc;
}
__host__ __device__ static constexpr auto
......@@ -551,7 +597,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const BGridDesc& b_grid_desc,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
......@@ -581,17 +627,17 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const auto GetBProblemsizeNK = [&]() {
if constexpr(BEnableLds)
{
return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1),
b_grid_desc_k0_n_k1.GetLength(I0) *
b_grid_desc_k0_n_k1.GetLength(I2));
return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) *
b_grid_desc.GetLength(I2));
}
else
{
return make_tuple(
b_grid_desc_k0_n_k1.GetLength(I1) * b_grid_desc_k0_n_k1.GetLength(I2) *
b_grid_desc_k0_n_k1.GetLength(I4),
b_grid_desc_k0_n_k1.GetLength(I0) * b_grid_desc_k0_n_k1.GetLength(I3) *
b_grid_desc_k0_n_k1.GetLength(I5));
b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5));
}
};
......@@ -702,7 +748,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
: 0;
static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
......@@ -737,7 +783,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AGridDesc& a_grid_desc,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const BGridDesc& b_grid_desc,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
......@@ -753,7 +799,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
p_b_grid, b_grid_desc.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -789,7 +835,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}();
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto b_block_desc = MakeBBlockDescriptor();
auto a_block_trait = [&](){
// A matrix blockwise copy
......@@ -886,7 +932,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_grid_desc),
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
......@@ -898,7 +944,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_n_k1,
b_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc,
......@@ -909,22 +955,36 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
else
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
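// The start coordinate below decomposes the flat thread id, assuming a wave size
// of 32 and NPerWmma == 16: tid / 32 selects the wave along N (dim I2),
// (tid % 32) / 16 selects the K-row within the wave (dim I3), and tid % 16
// selects the lane's N position inside the WMMA tile (dim I4).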
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v4<BDataType,
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_grid_desc),
decltype(b_block_desc),
Sequence<Number<K0PerBlock>{},
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2>,
2,
Sequence<0, 1, 2, 3, 4, 5>,
5,
BBlockTransferSrcScalarPerVector,
1>(
make_multi_index(0, get_thread_local_1d_id()/32 * 16 + get_thread_local_1d_id() % 16, 0));
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc,
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
(get_thread_local_1d_id() % 32) / 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(b_block_buf, b_blockwise_copy);
}
......@@ -945,7 +1005,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
BDataType,
AccDataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc)),
decltype(MakeBWaveDescriptor(b_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -973,7 +1033,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_grid_desc,
b_block_desc,
b_blockwise_copy,
b_grid_buf,
......