Commit a38ce024 authored by aska-0096

batched gemm ported

parent 686212eb
@@ -70,12 +70,12 @@ using DeviceOpInstanceKKNN =
 256,
 128,
 128,
-4,
+32,
 8,
 16,
 16,
-4,
-2,
+1,
+8,
 S<4, 64, 1>,
 S<1, 0, 2>,
 S<1, 0, 2>,
@@ -92,7 +92,7 @@ using DeviceOpInstanceKKNN =
 true,
 1,
 1,
-S<1, 32, 1, 8>,
+S<1, 128, 1, 2>,
 8>;
 using DeviceOpInstance = DeviceOpInstanceKKNN;
......
@@ -5,6 +5,9 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
+if(GPU_TARGETS MATCHES "gfx1100")
+add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+endif()
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
@@ -14,3 +17,8 @@ add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_soft
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
+if(GPU_TARGETS MATCHES "gfx1100")
+add_custom_target(example_gemm_scale_softmax_gemm_wmma)
+add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+endif()
\ No newline at end of file
@@ -76,10 +76,10 @@ template <index_t NumDimG,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,
-ck::index_t K0PerBlock,
+ck::index_t KPerBlock,
 ck::index_t K1,
-ck::index_t MPerWMMA,
-ck::index_t NPerWMMA,
+ck::index_t MPerWmma,
+ck::index_t NPerWmma,
 ck::index_t MRepeat,
 ck::index_t NRepeat,
 typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -123,14 +123,23 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 static constexpr auto I1 = Number<1>{};
 static constexpr auto I2 = Number<2>{};
 static constexpr auto I3 = Number<3>{};
+static constexpr auto I4 = Number<4>{};
+static constexpr auto I5 = Number<5>{};
 // K1 = Max Vector Access Pixels
 static constexpr auto K1Number = Number<K1>{};
+static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+static constexpr auto WmmaK = 16;
+static constexpr auto AEnableLds = NWaves == 1 ? false : true;
+static constexpr auto BEnableLds = MWaves == 1 ? false : true;
 static constexpr auto matrix_padder =
-    MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
+    MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
 // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
-static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
 {
 assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
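Note: the new MWaves/NWaves/AEnableLds/BEnableLds constants decide at compile time whether each operand is staged through LDS. The intent appears to be that when a block has only one wave along N, no two waves need the same slice of A, so A can go straight from global memory into registers (and symmetrically for B). A minimal standalone sketch of that arithmetic, with illustrative tile numbers that are not taken from this commit:

#include <cstdio>

int main()
{
    // Example tile configuration (illustrative values only).
    constexpr int MPerBlock = 128, NPerBlock = 64;
    constexpr int MPerWmma  = 16,  NPerWmma  = 16;
    constexpr int MRepeat   = 4,   NRepeat   = 4;

    // Waves per block along M and N, as computed in the device op above.
    constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 128 / 64 = 2
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 64 / 64  = 1

    // A only needs to be shared across waves that differ along N, so a single
    // N-wave lets A bypass LDS; likewise for B along M.
    constexpr bool AEnableLds = (NWaves != 1);
    constexpr bool BEnableLds = (MWaves != 1);

    static_assert(MPerBlock % (MRepeat * MPerWmma) == 0, "tile must divide evenly");
    std::printf("MWaves=%d NWaves=%d AEnableLds=%d BEnableLds=%d\n",
                MWaves, NWaves, AEnableLds, BEnableLds);
}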
...@@ -158,36 +167,69 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -158,36 +167,69 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for K0, K1, ... // lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
if constexpr(ASpec == TensorSpecialization::Packed) const auto a_grid_desc_m_k = [&](){
if constexpr(ASpec == TensorSpecialization::Packed)
{
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(M, K),
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
}();
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(AEnableLds)
{ {
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); const index_t K0 = K / K1;
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( return transform_tensor_descriptor(
make_tuple(M, K), a_grid_desc_m_k,
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}], make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}])); make_pass_through_transform(M)),
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
else else
{ {
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] constexpr auto A_KRow = WmmaK / K1;
const auto a_grid_desc_ms_ks = const auto A_KWmma = K / WmmaK;
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
const auto M0 = M / MPerBlock;
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_ms_ks, a_grid_desc_m_k,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_tuple(mDimIds, kDimIds), make_unmerge_transform(
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
} }
} }
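Note: MakeAGridDescriptor now builds the blockwise-copy view in one step. The raw A tensor is merged (and padded) into an [M, K] matrix, then K is unmerged into (K0, K1) with K0 = K / K1 on the LDS path, or into per-wave WMMA coordinates when AEnableLds is false; MakeBGridDescriptor_K0_N_K1 below applies the same K0/K1 split to B. A plain-arithmetic sketch of why the [M, K] and [K0, M, K1] views address the same elements (hypothetical sizes, row-major layout, not CK code):

#include <cassert>
#include <cstdio>

int main()
{
    // Hypothetical padded problem sizes.
    const int M = 256, K = 128, K1 = 8;
    assert(K % K1 == 0);
    const int K0 = K / K1; // 16

    // Row-major A[M][K]: offset(m, k) = m * K + k.
    // After the unmerge K -> (K0, K1), the same element is addressed as (k0, m, k1)
    // with k = k0 * K1 + k1, so both views touch identical memory.
    auto offset_m_k     = [&](int m, int k) { return m * K + k; };
    auto offset_k0_m_k1 = [&](int k0, int m, int k1) { return m * K + (k0 * K1 + k1); };

    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            assert(offset_m_k(m, k) == offset_k0_m_k1(k / K1, m, k % K1));

    std::printf("K0=%d, K1=%d: [M,K] and [K0,M,K1] views address the same data\n", K0, K1);
}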
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_K0_N_K1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{ {
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
...@@ -214,31 +256,45 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -214,31 +256,45 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for N0, N1, ... // lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
if constexpr(BSpec == TensorSpecialization::Packed) const auto b_grid_desc_n_k = [&](){
{ if constexpr(BSpec == TensorSpecialization::Packed)
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); {
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
make_tuple(N, K), const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}], make_tuple(N, K),
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}])); make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
} return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
else }
{ else
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] {
const auto b_grid_desc_ns_ks = // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
b_grid_desc_ns_ks, const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), b_grid_desc_ns_ks,
make_tuple(nDimIds, kDimIds), make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); const auto N = b_grid_desc_n_k.GetLength(I0);
} const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);
const index_t K0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
...@@ -393,8 +449,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -393,8 +449,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
} }
// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
...@@ -449,42 +503,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -449,42 +503,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EGridDesc_G_M_N e_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_;
}; };
// A desc for source in blockwise copy using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({},{}));
template <typename AGridDesc_M_K> using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1({},{}));
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / K1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
// GridwiseOp // GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
@@ -496,7 +516,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 DsDataType,
 EDataType,
 // InMemory Data Descriptor
-AGridDesc_K0_M_K1,
+AGridDesc,
 BGridDesc_K0_N_K1,
 DsGridDesc_M_N,
 EGridDesc_M_N,
@@ -508,9 +528,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 // Tiling Family
 MPerBlock,
 NPerBlock,
-K0PerBlock,
-MPerWMMA,
-NPerWMMA,
+KPerBlock,
+MPerWmma,
+NPerWmma,
 K1,
 MRepeat,
 NRepeat,
@@ -523,6 +543,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 ABlockTransferSrcScalarPerVector,
 ABlockTransferDstScalarPerVector_K1,
 false, // AThreadTransferSrcResetCoordinateAfterRun,
+AEnableLds,
 ABlockLdsAddExtraM,
 BBlockTransferThreadClusterLengths_K0_N_K1,
 BBlockTransferThreadClusterArrangeOrder,
@@ -531,6 +552,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 BBlockTransferSrcScalarPerVector,
 BBlockTransferDstScalarPerVector_K1,
 false, // BThreadTransferSrcResetCoordinateAfterRun,
+BEnableLds,
 BBlockLdsAddExtraN,
 CShuffleMRepeatPerShuffle,
 CShuffleNRepeatPerShuffle,
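Note: AEnableLds and BEnableLds are now forwarded to the gridwise op as extra template parameters alongside the existing transfer traits, so the LDS-versus-register staging choice is resolved at compile time inside the kernel. A stripped-down illustration of consuming such a flag with if constexpr (a sketch, not the actual GridwiseGemmMultipleD implementation):

#include <cstdio>

template <bool AEnableLds>
struct OperandStage
{
    static void describe()
    {
        if constexpr(AEnableLds)
            std::printf("A tile: global -> LDS -> WMMA (shared across waves)\n");
        else
            std::printf("A tile: global -> VGPR -> WMMA (each wave loads its own slice)\n");
    }
};

int main()
{
    OperandStage<true>::describe();
    OperandStage<false>::describe();
}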
...@@ -564,16 +586,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -564,16 +586,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{}, a_grid_desc_k0_m_k1_{},
b_grid_desc_n_k_{}, b_grid_desc_k0_n_k1_{},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{}, e_grid_desc_m_n_{},
ds_grid_desc_g_m_n_{ ds_grid_desc_g_m_n_{
DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
e_grid_desc_g_m_n_{ e_grid_desc_g_m_n_{
DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock{}, ds_grid_desc_mblock_mperblock_nblock_nperblock{},
e_grid_desc_mblock_mperblock_nblock_nperblock{}, e_grid_desc_mblock_mperblock_nblock_nperblock{},
block_2_ctile_map_{}, block_2_ctile_map_{},
...@@ -600,10 +620,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -600,10 +620,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
}); });
a_grid_desc_m_k_ = a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
b_grid_desc_n_k_ =
DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
ds_grid_desc_m_n_ = ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
...@@ -611,8 +629,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -611,8 +629,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
e_grid_desc_m_n_ = e_grid_desc_m_n_ =
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
...@@ -644,15 +660,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -644,15 +660,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType* p_e_grid_; EDataType* p_e_grid_;
// Tensor Descriptors // Tensor Descriptors
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc a_grid_desc_k0_m_k1_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
DsGridDesc_G_M_N ds_grid_desc_g_m_n_; DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
EGridDesc_G_M_N e_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock; ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -712,7 +726,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 BDataType,
 typename GridwiseOp::DsGridPointer,
 EDataType,
-DeviceOp::AGridDesc_K0_M_K1,
+DeviceOp::AGridDesc,
 DeviceOp::BGridDesc_K0_N_K1,
 typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
 typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -975,10 +989,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
 << BlockSize << ", "
 << MPerBlock << ", "
 << NPerBlock << ", "
-<< K0PerBlock << ", "
+<< KPerBlock << ", "
 << K1 << ", "
-<< MPerWMMA << ", "
-<< NPerWMMA << ", "
+<< MPerWmma << ", "
+<< NPerWmma << ", "
 << MRepeat << ", "
 << NRepeat
 << ">"
......
@@ -89,8 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
 static constexpr auto AEnableLds = NWaves == 1 ? false : true;
 static constexpr auto BEnableLds = MWaves == 1 ? false : true;
-// static constexpr auto AEnableLds = true;
-// static constexpr auto BEnableLds = true;
+// Force enable LDS if uncommented following
+// AEnableLds = true;
+// BEnableLds = true;
 static constexpr auto matrix_padder =
     MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
@@ -124,7 +125,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
 return transform_tensor_descriptor(
 a_grid_desc_m_k,
 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-           make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
+           make_pass_through_transform(M)),
 make_tuple(Sequence<1>{}, Sequence<0>{}),
 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
......
...@@ -296,20 +296,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -296,20 +296,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5); constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// Workaround, Freeze transform
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)), make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}), make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0, 3>{}, make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}),
make_tuple( make_tuple(Sequence<>{},
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
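Note: the "Workaround, Freeze transform" change above replaces a merge of (KWmma, 1) with a freeze of the unit-length axis plus a pass-through of KWmma. Merging with a size-1 dimension is an identity on the flattened offset, so both formulations should address the same elements; below is a simplified three-dimensional stand-in for the six-dimensional block descriptor (illustrative values only, not CK code):

#include <cassert>
#include <cstdio>

int main()
{
    const int KWmma = 4, MRepeat = 2, A_K1 = 8;

    // Old view: leading dim = merge(KWmma, 1) == KWmma.
    auto offset_merged = [&](int kw, int mr, int k1) {
        int merged = kw * 1 + 0;                // merge with a size-1 axis
        return (merged * MRepeat + mr) * A_K1 + k1;
    };
    // New view: the size-1 axis is frozen at index 0 and KWmma passes through.
    auto offset_frozen = [&](int kw, int mr, int k1) {
        return (kw * MRepeat + mr) * A_K1 + k1;
    };

    for(int kw = 0; kw < KWmma; ++kw)
        for(int mr = 0; mr < MRepeat; ++mr)
            for(int k1 = 0; k1 < A_K1; ++k1)
                assert(offset_merged(kw, mr, k1) == offset_frozen(kw, mr, k1));

    std::printf("merge-with-1 and freeze+pass-through give identical offsets\n");
}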
...@@ -782,6 +789,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -782,6 +789,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
MRepeat, MRepeat,
LRepeat, LRepeat,
KPack, KPack,
AEnableLds,
B0EnableLds,
true>{}; // C' = B' x A' true>{}; // C' = B' x A'
...@@ -968,6 +977,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -968,6 +977,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack, KPack,
false,
B1EnableLds,
true>{make_tuple(0, 0, 0, 0, 0)}; true>{make_tuple(0, 0, 0, 0, 0)};
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
......
@@ -69,7 +69,7 @@ __global__ void
 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
-__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
+__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
 DsPointer p_ds_grid_grp;
@@ -148,7 +148,7 @@ __global__ void
 const Block2CTileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
+__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
 const index_t num_blocks_per_batch =
     __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -237,7 +237,7 @@ __global__ void
 const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
+__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
     p_b_grid,
...@@ -451,20 +451,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -451,20 +451,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5); constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// Workaround, Freeze transform
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)), make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}), make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0, 3>{}, make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<5>{}), Sequence<5>{}),
make_tuple( make_tuple(Sequence<>{},
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
@@ -540,19 +547,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
 Number<NumDTensor>{});
 }
-__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
-{
-    // LDS allocation for A and B: be careful of alignment
-    const index_t gemm_bytes_end =
-        SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
-        SharedMemTrait::b_block_space_size_aligned * sizeof(BDataType);
-    const index_t c_block_bytes_end =
-        SharedMemTrait::c_shuffle_block_space_size * sizeof(CShuffleDataType);
-    return math::max(gemm_bytes_end, c_block_bytes_end);
-}
 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
 template <typename Block2CTileMap>
 __host__ __device__ static constexpr bool
@@ -650,7 +644,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
 {
-const index_t num_loop = K / (K0PerBlock * K1);
+const index_t num_loop = K / KPerBlock;
 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
 }
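Note: with K0PerBlock folded into a single KPerBlock parameter (KPerBlock = K0PerBlock * K1), the main-loop count is now written directly as K / KPerBlock; the old and new expressions agree whenever the parameters are consistent. A one-line arithmetic check with example numbers only:

#include <cassert>

int main()
{
    const int K = 4096, K1 = 8, K0PerBlock = 4;
    const int KPerBlock = K0PerBlock * K1;          // 32, the new single tiling parameter
    assert(K / (K0PerBlock * K1) == K / KPerBlock); // old and new loop counts match
    assert(K / KPerBlock == 128);
}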
...@@ -704,11 +698,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -704,11 +698,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
static constexpr auto a_block_space_size_aligned = static constexpr auto a_block_space_size_aligned =
AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
max_lds_align): 0; max_lds_align)
: 0;
static constexpr auto b_block_space_size_aligned = static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple( BEnableLds ? math::integer_least_multiple(
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(), GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
max_lds_align): 0; max_lds_align)
: 0;
static constexpr auto a_block_space_offset = 0; static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned; static constexpr auto b_block_space_offset = a_block_space_size_aligned;
...@@ -719,6 +715,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -719,6 +715,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
.GetElementSpaceSize(); .GetElementSpaceSize();
static constexpr auto c_shuffle_block_space_offset = 0; static constexpr auto c_shuffle_block_space_offset = 0;
static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType));
}; };
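Note: SharedMemTrait now also exposes lds_size, the larger of the GEMM staging footprint (A plus B block buffers, each counted only when its EnableLds flag is set) and the C-shuffle footprint, since the two phases reuse the same allocation; the kernels above size __shared__ char p_shared[...] from it directly. A standalone sketch of that sizing rule with made-up element counts and element sizes:

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Round n up to a multiple of align (mirrors math::integer_least_multiple).
constexpr int integer_least_multiple(int n, int align)
{
    return ((n + align - 1) / align) * align;
}

int main()
{
    // Illustrative element counts, not the real block descriptors.
    constexpr int  a_block_elems = 128 * 32, b_block_elems = 64 * 32, c_shuffle_elems = 128 * 64;
    constexpr int  max_lds_align = 8;
    constexpr bool AEnableLds = true, BEnableLds = false;

    constexpr int a_sz = AEnableLds ? integer_least_multiple(a_block_elems, max_lds_align) : 0;
    constexpr int b_sz = BEnableLds ? integer_least_multiple(b_block_elems, max_lds_align) : 0;

    // Assume 2-byte fp16 operands and a 4-byte shuffle type for this sketch.
    constexpr std::size_t lds_size =
        std::max(c_shuffle_elems * sizeof(float),
                 std::size_t(a_sz) * 2 + std::size_t(b_sz) * 2);

    std::printf("lds_size = %zu bytes\n", lds_size);
}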
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
...@@ -796,7 +797,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -796,7 +797,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
{ {
constexpr auto K0PerBlock = KPerBlock/ K1; constexpr auto K0PerBlock = KPerBlock/ K1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatA*>(p_shared), static_cast<ADataType*>(p_shared),
a_block_desc.GetElementSpaceSize()); a_block_desc.GetElementSpaceSize());
auto a_blockwise_copy = auto a_blockwise_copy =
...@@ -807,8 +808,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -807,8 +808,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>, /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA, /* typename SrcData, */ ADataType,
/* typename DstData, */ FloatA, /* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(a_grid_desc), /* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc), /* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
...@@ -835,13 +836,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -835,13 +836,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>( auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize()); a_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical // Limitation: NumDim of Src and Dst descriptor should be identical
auto a_blockwise_copy = auto a_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatA, ThreadwiseTensorSliceTransfer_v2<ADataType,
FloatA, ADataType,
decltype(a_grid_desc), decltype(a_grid_desc),
decltype(a_block_desc), decltype(a_block_desc),
Sequence<Number<KWmmaPerBlock>{}, Sequence<Number<KWmmaPerBlock>{},
...@@ -872,7 +873,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -872,7 +873,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
{ {
constexpr auto K0PerBlock = KPerBlock/ K1; constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned, static_cast<BDataType*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
b_block_desc.GetElementSpaceSize()); b_block_desc.GetElementSpaceSize());
auto b_blockwise_copy = auto b_blockwise_copy =
...@@ -883,8 +884,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -883,8 +884,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatB, BDataType,
FloatB, BDataType,
decltype(b_grid_desc_k0_n_k1), decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc), decltype(b_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -909,11 +910,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -909,11 +910,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
else else
{ {
constexpr auto K0PerBlock = KPerBlock/ K1; constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>( auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize()); b_block_desc.GetElementSpaceSize());
auto b_blockwise_copy = auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v4<FloatB, ThreadwiseTensorSliceTransfer_v4<BDataType,
FloatB, BDataType,
decltype(b_grid_desc_k0_n_k1), decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc), decltype(b_block_desc),
Sequence<Number<K0PerBlock>{}, Sequence<Number<K0PerBlock>{},
...@@ -952,38 +953,35 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -952,38 +953,35 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
NPerWmma, NPerWmma,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; KPack,
AEnableLds,
BEnableLds>{};
// Prepare Register for C matrix // Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
/*******************************************************************************/ /*******************************************************************************/
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<ADataType*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<BDataType*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
// Shift Per SUB_K // Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline // gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc, GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc_k0perblock_mperblock_k1, a_block_desc,
a_blockwise_copy, a_blockwise_copy,
a_grid_buf, a_grid_buf,
a_block_buf, a_block_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
b_block_desc_k0perblock_nperblock_k1, b_block_desc,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
K0BlockMainLoop); KBlockMainLoop);
/*******************************************************************************/ /*******************************************************************************/
// write out to C, implement shuffle // write out to C, implement shuffle
{ {
......
@@ -56,8 +56,6 @@ struct GridwiseGemmPipeline_v1<1, true, true>
 CThreadBuffer& c_thread_buf,
 index_t num_loop)
 {
-if(get_thread_local_1d_id()<32);
-printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
 // preload data into LDS
 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -306,8 +304,6 @@ struct GridwiseGemmPipeline_v1<1, false, true>
 },
 Number<a_block_desc.GetLengths().GetSize()>{});
 #endif
-if(get_thread_local_1d_id()<32);
-printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
 constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
 auto a_block_buf_switch = a_block_buf;
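Note: the two debug fragments deleted here had a stray semicolon after the if, so the printf actually ran on every lane, unconditionally, inside the pipeline; removing them drops that overhead entirely. An illustrative reconstruction of the pitfall (plain host code, not CK):

#include <cstdio>

int main()
{
    const int tid = 40;

    // What the deleted lines did: the stray ';' ends the 'if', so the print is unconditional.
    if(tid < 32);
    std::printf("printed even though tid=%d is not < 32\n", tid);

    // What was presumably intended (and is now simply removed from the pipeline):
    if(tid < 32)
    {
        std::printf("only the first 32 lanes would print\n");
    }
}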
......
@@ -731,7 +731,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
 constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
 // gridwise GEMM pipeline
-const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
+const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
     a_block_desc,
     a_blockwise_copy,
@@ -746,7 +746,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
 b_block_slice_copy_step,
 blockwise_gemm,
 c_thread_buf,
-K0BlockMainLoop);
+KBlockMainLoop);
 /*******************************************************************************/
 // write out to C, implement shuffle
 {
......