Commit a38ce024 authored by aska-0096

batched gemm ported

parent 686212eb
......@@ -70,12 +70,12 @@ using DeviceOpInstanceKKNN =
256,
128,
128,
4,
32,
8,
16,
16,
4,
2,
1,
8,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -92,7 +92,7 @@ using DeviceOpInstanceKKNN =
true,
1,
1,
S<1, 32, 1, 8>,
S<1, 128, 1, 2>,
8>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
......
......@@ -5,6 +5,9 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100")
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
endif()
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......@@ -14,3 +17,8 @@ add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_soft
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
if(GPU_TARGETS MATCHES "gfx1100")
add_custom_target(example_gemm_scale_softmax_gemm_wmma)
add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
endif()
\ No newline at end of file
......@@ -76,10 +76,10 @@ template <index_t NumDimG,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
......@@ -123,14 +123,23 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// K1 = maximum number of elements accessed per vector load
static constexpr auto K1Number = Number<K1>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
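// Illustrative arithmetic (hypothetical tile sizes, not taken from this commit): with
// MPerBlock = 128, MRepeat = 2, MPerWmma = 16 we get MWaves = 128 / (2 * 16) = 4, so B keeps
// its LDS staging; a skinny tile with NPerBlock = 16, NRepeat = 1, NPerWmma = 16 gives
// NWaves = 1, and A then bypasses LDS and is staged in registers instead.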
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
......@@ -158,36 +167,69 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
if constexpr(ASpec == TensorSpecialization::Packed)
const auto a_grid_desc_m_k = [&](){
if constexpr(ASpec == TensorSpecialization::Packed)
{
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(M, K),
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
}();
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(AEnableLds)
{
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(M, K),
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const index_t K0 = K / K1;
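// Sketch of the resulting source view (as read from the transform below): the padded
// A[M, K] descriptor is re-expressed as A[K0, M, K1] with K = K0 * K1, which is the
// layout the LDS blockwise copy expects.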
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
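// Sketch of the LDS-bypass view (reading the unmerge/sequence mapping below): A is exposed
// directly in the 6-D order [A_KWmma, M0 * MRepeat, MWaves, A_KRow, MPerWmma, K1], where
// A_KRow = WmmaK / K1, so threads fetch their WMMA operand without an LDS stage.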
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
static auto MakeBGridDescriptor_K0_N_K1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
......@@ -214,31 +256,45 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
if constexpr(BSpec == TensorSpecialization::Packed)
{
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
make_tuple(N, K),
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_n_k = [&](){
if constexpr(BSpec == TensorSpecialization::Packed)
{
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
make_tuple(N, K),
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);
const index_t K0 = K / K1;
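// As with A in the LDS path, the padded B[N, K] descriptor is re-expressed below as
// B[K0, N, K1] with K = K0 * K1 for the blockwise copy.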
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
......@@ -393,8 +449,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
......@@ -449,42 +503,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EGridDesc_G_M_N e_grid_desc_g_m_n_;
};
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / K1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({},{}));
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1({},{}));
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
......@@ -496,7 +516,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_K0_M_K1,
AGridDesc,
BGridDesc_K0_N_K1,
DsGridDesc_M_N,
EGridDesc_M_N,
......@@ -508,9 +528,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
KPerBlock,
MPerWmma,
NPerWmma,
K1,
MRepeat,
NRepeat,
......@@ -523,6 +543,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -531,6 +552,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......@@ -564,16 +586,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{},
b_grid_desc_n_k_{},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
ds_grid_desc_g_m_n_{
DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
e_grid_desc_g_m_n_{
DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
e_grid_desc_mblock_mperblock_nblock_nperblock{},
block_2_ctile_map_{},
......@@ -600,10 +620,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
});
a_grid_desc_m_k_ =
DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_n_k_ =
DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
......@@ -611,8 +629,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
e_grid_desc_m_n_ =
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
......@@ -644,15 +660,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType* p_e_grid_;
// Tensor Descriptors
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
AGridDesc a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
EGridDesc_G_M_N e_grid_desc_g_m_n_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -712,7 +726,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
DeviceOp::AGridDesc_K0_M_K1,
DeviceOp::AGridDesc,
DeviceOp::BGridDesc_K0_N_K1,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -975,10 +989,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
......
......@@ -89,8 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
// static constexpr auto AEnableLds = true;
// static constexpr auto BEnableLds = true;
// Uncomment the following to force-enable LDS staging:
// AEnableLds = true;
// BEnableLds = true;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......@@ -124,7 +125,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
......
......@@ -296,20 +296,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// Workaround: freeze transform
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0, 3>{},
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
}
}();
......@@ -782,6 +789,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
MRepeat,
LRepeat,
KPack,
AEnableLds,
B0EnableLds,
true>{}; // C' = B' x A'
......@@ -968,6 +977,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
MRepeat,
NRepeat,
KPack,
false,
B1EnableLds,
true>{make_tuple(0, 0, 0, 0, 0)};
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
......
......@@ -69,7 +69,7 @@ __global__ void
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
DsPointer p_ds_grid_grp;
......@@ -148,7 +148,7 @@ __global__ void
const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -237,7 +237,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
......@@ -451,20 +451,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// Workaround: freeze transform
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0, 3>{},
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
}
}();
......@@ -540,19 +547,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
const index_t gemm_bytes_end =
SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType)+
SharedMemTrait::b_block_space_size_aligned * sizeof(BDataType);
const index_t c_block_bytes_end =
SharedMemTrait::c_shuffle_block_space_size * sizeof(CShuffleDataType);
return math::max(gemm_bytes_end, c_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
......@@ -650,7 +644,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
......@@ -704,11 +698,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
static constexpr auto a_block_space_size_aligned =
AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
max_lds_align): 0;
max_lds_align)
: 0;
static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
max_lds_align): 0;
max_lds_align)
: 0;
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned;
......@@ -719,6 +715,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
.GetElementSpaceSize();
static constexpr auto c_shuffle_block_space_offset = 0;
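// The A/B staging buffers and the C-shuffle buffer reuse the same LDS region (GEMM main
// loop first, C shuffle afterwards), so only the larger of the two footprints is reserved.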
static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType));
};
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
......@@ -796,7 +797,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
{
constexpr auto K0PerBlock = KPerBlock / K1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatA*>(p_shared),
static_cast<ADataType*>(p_shared),
a_block_desc.GetElementSpaceSize());
auto a_blockwise_copy =
......@@ -807,8 +808,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA,
/* typename DstData, */ FloatA,
/* typename SrcData, */ ADataType,
/* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
......@@ -835,13 +836,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto a_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatA,
FloatA,
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
decltype(a_grid_desc),
decltype(a_block_desc),
Sequence<Number<KWmmaPerBlock>{},
......@@ -872,7 +873,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
{
constexpr auto K0PerBlock = KPerBlock / K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
static_cast<BDataType*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
b_block_desc.GetElementSpaceSize());
auto b_blockwise_copy =
......@@ -883,8 +884,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
FloatB,
BDataType,
BDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
......@@ -909,11 +910,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
else
{
constexpr auto K0PerBlock = KPerBlock / K1;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
ThreadwiseTensorSliceTransfer_v4<BDataType,
BDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc),
Sequence<Number<K0PerBlock>{},
......@@ -952,38 +953,35 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack,
AEnableLds,
BEnableLds>{};
// Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
/*******************************************************************************/
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<ADataType*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<BDataType*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
/*******************************************************************************/
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
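// KBlockMainLoop counts main-loop iterations along K; it matches the K / KPerBlock ratio
// used by CalculateHasMainKBlockLoop above.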
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc_k0perblock_mperblock_k1,
a_block_desc,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0perblock_nperblock_k1,
b_block_desc,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
KBlockMainLoop);
/*******************************************************************************/
// write out to C, implement shuffle
{
......
......@@ -56,8 +56,6 @@ struct GridwiseGemmPipeline_v1<1, true, true>
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
if(get_thread_local_1d_id()<32);
printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
......@@ -306,8 +304,6 @@ struct GridwiseGemmPipeline_v1<1, false, true>
},
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
if(get_thread_local_1d_id()<32);
printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
......
......@@ -731,7 +731,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc,
a_blockwise_copy,
......@@ -746,7 +746,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
KBlockMainLoop);
/*******************************************************************************/
// write out to C, implement shuffle
{
......