"vscode:/vscode.git/clone" did not exist on "e9203f9157d3f296911a6d5b122e6cf72a7cd6fe"
Commit 04c6a978 authored by aska-0096

Skip B-Lds real gemm

parent f00dab9f
@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8, // K1
16, // MPerWmma
16, // NPerWmma
2, // M Repeat
4, // N-Repeat
8, // M Repeat
1, // N-Repeat
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 64, 1, 4>,
S<1, 16, 1, 16>,
8>;
// clang-format on
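For orientation: in these WMMA instances the block tile factors as MPerBlock = MRepeat * MWaves * MPerWmma (likewise for N), so raising M-Repeat to 8 while cutting N-Repeat to 1 reshapes the tile rather than the thread block; the C-shuffle cluster change from S<1, 64, 1, 4> to S<1, 16, 1, 16> keeps 1*64*1*4 = 1*16*1*16 = 256 threads. A minimal sketch of the arithmetic; the MWaves/NWaves values are illustrative assumptions, not taken from this file:

// Sketch only: MWaves / NWaves are assumed, not visible in this hunk;
// the repeats and MPerWmma / NPerWmma come from the parameters above.
constexpr int MPerWmma = 16, NPerWmma = 16;
constexpr int MRepeat = 8, NRepeat = 1;  // new values in this commit
constexpr int MWaves = 1, NWaves = 4;    // hypothetical wave layout
constexpr int MPerBlock = MRepeat * MWaves * MPerWmma; // 8 * 1 * 16 = 128
constexpr int NPerBlock = NRepeat * NWaves * NPerWmma; // 1 * 4 * 16 = 64
static_assert(MPerBlock == 128 && NPerBlock == 64, "tile implied by the assumptions");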
@@ -106,12 +106,13 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
const auto a_grid_desc_mraw_kraw =
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
#endif
}();
const auto M = a_grid_desc_m_k.GetLength(I0);
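With this change the ENABLE_COLMAJOR branch no longer returns the raw descriptor but routes it through matrix_padder, matching the row-major path above it. Schematically, padding rounds each raw extent up to a tile multiple; a sketch assuming that round-up behavior (MPerBlock = 128 is a made-up example value):

// Assumed behavior, for illustration: the padder rounds extents up to tile multiples.
constexpr int round_up(int x, int tile) { return (x + tile - 1) / tile * tile; }
// e.g. MRaw = 1000 with a hypothetical MPerBlock = 128 pads to 1024; K likewise.
static_assert(round_up(1000, 128) == 1024, "");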
@@ -146,34 +147,57 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
const auto b_grid_desc_n_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
const auto b_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);
const index_t K0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
if constexpr(BEnableLds)
{
const index_t K0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
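To make the new BEnableLds == false branch concrete: K is unmerged as B_KWmma x (WmmaK / K1) x K1 and N as (N0 * NRepeat) x NWaves x NPerWmma, yielding the 6-d (KWmma, N0*NRepeat, NWave, KRow, NPerWmma, K1) view that the register path reads through. A standalone sketch of that index split with small illustrative sizes (plain C++, not CK code):

#include <cassert>

int main()
{
    // Illustrative sizes only; the split pattern mirrors the transform above.
    const int WmmaK = 16, K1 = 8;              // so WmmaK / K1 == 2 K-rows
    const int NWaves = 2, NPerWmma = 16;
    const int k = 27, n = 45;                  // sample flat coordinates
    const int kwmma = k / WmmaK;               // WMMA step along K
    const int krow  = (k % WmmaK) / K1;        // row inside the WMMA K fragment
    const int k1    = k % K1;                  // element inside the K1 vector
    const int nrep  = n / (NWaves * NPerWmma); // N0 * NRepeat coordinate
    const int nwave = (n / NPerWmma) % NWaves; // wave along N
    const int nperw = n % NPerWmma;            // lane along N
    // Reassembling proves the decomposition is exact.
    assert((kwmma * (WmmaK / K1) + krow) * K1 + k1 == k);
    assert((nrep * NWaves + nwave) * NPerWmma + nperw == n);
    return 0;
}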
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
@@ -196,7 +220,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
@@ -209,7 +233,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc,
BGridDesc_K0_N_K1,
BGridDesc,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
@@ -281,7 +305,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{
a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
@@ -301,7 +325,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc a_grid_desc_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
BGridDesc b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -371,7 +395,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
@@ -404,7 +428,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
@@ -309,9 +309,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -364,6 +364,100 @@ struct GridwiseGemmPipeline_v1<1, false, true>
template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
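Stripped to its schedule, this new <1, true, false> specialization stages A through LDS while double-buffering B in registers; paraphrased (a readability sketch, not the actual code):

// Prologue: RunRead A (global -> regs), Run B into buf0 (global -> regs),
//           advance both source windows, clear C, RunWrite A into LDS.
// Main loop (num_loop - 1 iterations):
//   Run B into buf1 (prefetch the next tile; overlaps the math below)
//   sync LDS, RunRead the next A tile
//   C += gemm(A_lds, B_buf0)
//   sync LDS, advance windows, RunWrite A into LDS
//   buf0 = buf1 (publish the prefetched B tile)
// Tail: sync LDS, C += gemm(A_lds, B_buf0), sync LDS.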
template <>
@@ -18,11 +18,11 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename ADataType,
typename BDataType,
typename CDataType,
typename AGridDesc,
typename BGridDesc_K0_N_K1,
typename BGridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -33,11 +33,11 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_wmma(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const BGridDesc b_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
@@ -53,7 +53,7 @@ __global__ void
p_c_grid,
p_shared,
a_grid_desc,
b_grid_desc_k0_n_k1,
b_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
@@ -64,7 +64,7 @@ __global__ void
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc;
ignore = b_grid_desc_k0_n_k1;
ignore = b_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
@@ -74,14 +74,14 @@ __global__ void
}
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename FloatCShuffle,
typename FloatC,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc,
typename BGridDesc_K0_N_K1,
typename BGridDesc,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -181,6 +181,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return a_block_desc;
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor()
{
constexpr auto b_block_desc = [&]() {
if constexpr(BEnableLds)
{
// K0->N->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto max_lds_align = K1;
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
}
}();
return b_block_desc;
}
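In the BEnableLds branch, BBlockLdsExtraN pads each K0 row by one extra K1 vector so successive rows skew across LDS banks. The offset arithmetic the padded descriptor encodes, as a standalone sketch with illustrative sizes:

// Illustrative sizes; only the stride pattern mirrors the descriptor above.
constexpr int K1 = 8, NPerBlock = 64;
constexpr int offset(int k0, int n, int k1)
{
    return k0 * (NPerBlock + 1) * K1 + n * K1 + k1; // one extra K1 vector per k0 row
}
// Adjacent k0 rows are (NPerBlock + 1) * K1 elements apart, not NPerBlock * K1.
static_assert(offset(1, 0, 0) == (NPerBlock + 1) * K1, "");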
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
@@ -292,43 +326,56 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return a_wave_desc;
}
template <typename BBlockDesc_BK0_N_BK1>
template <typename BBlockDesc_>
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
MakeBWaveDescriptor(const BBlockDesc_&)
{
constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
return transform_tensor_descriptor(
BBlockDesc_BK0_N_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
constexpr auto K0PerBlock = KPerBlock / K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
// BK0_N_BK1 -> BK0_NRepeat_NWaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_NWaves_NPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I5);
// Workaround: freeze the KRow dim (always zero in this per-thread view) to drop it from the output
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<NRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
}
}();
return b_block_desc_k0perblock_nperblock_k1;
return b_wave_desc;
}
__host__ __device__ static constexpr auto
@@ -349,7 +396,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const BGridDesc& b_grid_desc,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
@@ -378,17 +425,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const auto GetBProblemsizeNK = [&]() {
if constexpr(BEnableLds)
{
return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1),
b_grid_desc_k0_n_k1.GetLength(I0) *
b_grid_desc_k0_n_k1.GetLength(I2));
return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) *
b_grid_desc.GetLength(I2));
}
else
{
return make_tuple(
b_grid_desc_k0_n_k1.GetLength(I1) * b_grid_desc_k0_n_k1.GetLength(I2) *
b_grid_desc_k0_n_k1.GetLength(I4),
b_grid_desc_k0_n_k1.GetLength(I0) * b_grid_desc_k0_n_k1.GetLength(I3) *
b_grid_desc_k0_n_k1.GetLength(I5));
b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5));
}
};
@@ -484,9 +531,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
max_lds_align)
: 0;
static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
max_lds_align)
BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
static constexpr auto a_block_space_offset = 0;
@@ -500,18 +546,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static constexpr auto c_shuffle_block_space_offset = 0;
static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(FloatCShuffle),
a_block_space_size_aligned * sizeof(FloatA) +
b_block_space_size_aligned * sizeof(FloatB));
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType));
};
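The max is valid because the C-shuffle stage only starts once the A/B staging buffers are dead, so both phases can reuse one allocation. Rough arithmetic with illustrative fp16 tile sizes (none of these numbers come from this diff):

// Illustrative only: 128x64 A tile, 64x64 B tile, 128x64 C-shuffle tile, fp16.
constexpr unsigned a_bytes = 128 * 64 * 2; // a_block_space_size_aligned * sizeof(ADataType)
constexpr unsigned b_bytes =  64 * 64 * 2; // would be 0 instead when !BEnableLds
constexpr unsigned c_bytes = 128 * 64 * 2; // c_shuffle space * sizeof(CShuffleDataType)
constexpr unsigned lds_size = c_bytes > a_bytes + b_bytes ? c_bytes : a_bytes + b_bytes;
static_assert(lds_size == 24 * 1024, "max(16 KiB, 16 KiB + 8 KiB)");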
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc& a_grid_desc,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const BGridDesc& b_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
@@ -525,7 +571,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
p_b_grid, b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -554,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}();
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto b_block_desc = MakeBBlockDescriptor();
auto a_block_trait = [&](){
// A matrix blockwise copy
@@ -562,7 +608,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatA*>(p_shared),
static_cast<ADataType*>(p_shared),
SharedMemTrait::a_block_space_size_aligned);
auto a_blockwise_copy =
@@ -573,8 +619,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA,
/* typename DstData, */ FloatA,
/* typename SrcData, */ ADataType,
/* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
@@ -601,13 +647,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto a_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatA,
FloatA,
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
decltype(a_grid_desc),
decltype(a_block_desc),
Sequence<Number<KWmmaPerBlock>{},
@@ -638,7 +684,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB*>(p_shared) + SharedMemTrait::b_block_space_offset,
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
SharedMemTrait::b_block_space_size_aligned);
auto b_blockwise_copy =
@@ -649,9 +695,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
FloatB,
decltype(b_grid_desc_k0_n_k1),
BDataType,
BDataType,
decltype(b_grid_desc),
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
@@ -663,7 +709,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_n_k1,
b_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc,
@@ -674,23 +720,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}
else
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_grid_desc_k0_n_k1),
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc),
decltype(b_block_desc),
Sequence<Number<K0PerBlock>{},
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2>,
2,
Sequence<0, 1, 2, 3, 4, 5>,
5,
BBlockTransferSrcScalarPerVector,
1>(
make_multi_index(0, get_thread_local_1d_id()/32 * 16 + get_thread_local_1d_id() % 16, 0));
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc,
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(b_block_buf, b_blockwise_copy);
}
};
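The new per-thread origin decodes the flat thread id against a wave-32 WMMA layout: tid / 32 selects the NWave, (tid % 32) / 16 the K-row half of the wave, and tid % 16 the lane along NPerWmma. A standalone sketch of that decoding (wave size 32 is assumed, as the literals in the index expression imply):

#include <cstdio>

int main()
{
    for (int tid : {0, 15, 16, 31, 32, 63})
    {
        const int wave = tid / 32;        // NWave coordinate
        const int krow = (tid % 32) / 16; // K-row half within the wave
        const int lane = tid % 16;        // NPerWmma coordinate
        std::printf("tid %2d -> wave %d, krow %d, lane %2d\n", tid, wave, krow, lane);
    }
    return 0;
}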
@@ -706,11 +766,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
auto blockwise_gemm =
BlockwiseGemmWMMA<BlockSize,
FloatA,
FloatB,
FloatAcc,
ADataType,
BDataType,
AccDataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc)),
decltype(MakeBWaveDescriptor(b_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
@@ -738,7 +798,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_grid_desc,
b_block_desc,
b_blockwise_copy,
b_grid_buf,
@@ -768,7 +828,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
SharedMemTrait::c_shuffle_block_space_size);
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
@@ -815,8 +875,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatCShuffle,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
ck::tensor_operation::element_wise::PassThrough,
@@ -854,8 +914,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,