Commit 04c6a978 authored by aska-0096

Skip B-LDS in real GEMM

parent f00dab9f
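What this commit does: when BEnableLds is false, the B operand is no longer staged through LDS (shared memory); each thread loads its B fragments from global memory straight into VGPRs with a threadwise copy, while A keeps the LDS path. A minimal single-threaded sketch of that data flow, not CK code — all names here (TILE, lds_a, reg_b) are illustrative assumptions:

```cpp
#include <array>
#include <cstdio>

constexpr int N = 8, TILE = 4;

int main()
{
    float a[N][N], b[N][N], c[TILE][N] = {}; // one TILE-row block of C
    for(int i = 0; i < N; ++i)
        for(int j = 0; j < N; ++j)
        {
            a[i][j] = static_cast<float>(i + j);
            b[i][j] = static_cast<float>(i - j);
        }

    std::array<float, TILE * TILE> lds_a{}; // stand-in for shared memory (LDS)
    std::array<float, TILE> reg_b{};        // stand-in for per-thread VGPRs

    for(int k0 = 0; k0 < N; k0 += TILE)
    {
        // A path: global -> "LDS" -> math (two hops; on a GPU this staging
        // needs a block-wide barrier before the tile can be consumed)
        for(int i = 0; i < TILE; ++i)
            for(int k = 0; k < TILE; ++k)
                lds_a[i * TILE + k] = a[i][k0 + k];

        for(int j = 0; j < N; ++j)
        {
            // B path with BEnableLds == false: global -> "registers" directly
            for(int k = 0; k < TILE; ++k)
                reg_b[k] = b[k0 + k][j];

            for(int i = 0; i < TILE; ++i)
                for(int k = 0; k < TILE; ++k)
                    c[i][j] += lds_a[i * TILE + k] * reg_b[k];
        }
    }
    std::printf("c[0][1] = %.1f\n", c[0][1]); // 112.0
}
```

The point is the missing second hop on the B path: no block-wide staging buffer in LDS and no extra synchronization before B can be consumed.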
@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     8,  // K1
     16, // MPerWmma
     16, // NPerWmma
-    2,  // M Repeat
-    4,  // N-Repeat
+    8,  // M Repeat
+    1,  // N-Repeat
     S<4, 64, 1>,
     S<1, 0, 2>,
     S<1, 0, 2>,
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     true,
     1, // C shuffle (M Repeat) Per store
     1, // C shuffle (N Repeat) Per store
-    S<1, 64, 1, 4>,
+    S<1, 16, 1, 16>,
     8>;
 // clang-format on
@@ -106,12 +106,13 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
         }
-#ifdef ENABLE_COLMAJOR
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
         {
-            return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+            const auto a_grid_desc_mraw_kraw =
+                make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
+            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
         }
-#endif
     }();

     const auto M = a_grid_desc_m_k.GetLength(I0);
@@ -146,34 +147,57 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
     {
-        const auto b_grid_desc_nraw_kraw = [&]() {
+        const auto b_grid_desc_n_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
+                const auto b_grid_desc_nraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                 make_tuple(I1, StrideB));
+                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
+                const auto b_grid_desc_nraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                 make_tuple(StrideB, I1));
+                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
             }
         }();

-        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
-
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = b_grid_desc_n_k.GetLength(I1);

         assert(K % K1 == 0);
-        const index_t K0 = K / K1;

-        return transform_tensor_descriptor(
-            b_grid_desc_n_k,
-            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                       make_pass_through_transform(N)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        if constexpr(BEnableLds)
+        {
+            const index_t K0 = K / K1;
+
+            return transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+        else
+        {
+            constexpr auto B_KRow = WmmaK / K1;
+            const auto B_KWmma    = K / WmmaK;
+            const auto N0         = N / NPerBlock;
+            return transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
+                           make_unmerge_transform(
+                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+        }
     }

     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
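For the BEnableLds == false path above, the grid descriptor splits K into (B_KWmma, B_KRow, K1) with B_KRow = WmmaK / K1, and N into (N0 * NRepeat, NWaves, NPerWmma). A small sketch of the index arithmetic behind those two unmerge transforms; the concrete sizes (WmmaK = 16, K1 = 8, NWaves = 2, NPerWmma = 16) are assumptions for illustration:

```cpp
#include <cstdio>

int main()
{
    constexpr int WmmaK = 16, K1 = 8, NWaves = 2, NPerWmma = 16;
    constexpr int B_KRow = WmmaK / K1; // 2

    const int k = 45, n = 77; // an arbitrary (n, k) element of B

    // k -> (kWmma, kRow, k1): k = (kWmma * B_KRow + kRow) * K1 + k1
    const int kWmma = k / WmmaK;
    const int kRow  = (k % WmmaK) / K1;
    const int k1    = k % K1;

    // n -> (nRepeatIdx, nWave, nPerWmma):
    // n = (nRepeatIdx * NWaves + nWave) * NPerWmma + nPerWmma
    const int nRepeatIdx = n / (NWaves * NPerWmma);
    const int nWave      = (n % (NWaves * NPerWmma)) / NPerWmma;
    const int nPerWmma   = n % NPerWmma;

    std::printf("k=%d -> (%d,%d,%d), n=%d -> (%d,%d,%d)\n",
                k, kWmma, kRow, k1, n, nRepeatIdx, nWave, nPerWmma);

    // Recompose to check the unmerge is lossless.
    std::printf("k back=%d, n back=%d\n",
                (kWmma * B_KRow + kRow) * K1 + k1,
                (nRepeatIdx * NWaves + nWave) * NPerWmma + nPerWmma);
}
```

The recomposition check shows the split is lossless; the Sequence arguments in the transform then just interleave these six coordinates into the dimension order the threadwise copy expects.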
@@ -196,7 +220,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     // Gridwise descriptor, mapping to the whole given problem.
     using AGridDesc    = decltype(MakeAGridDescriptor(1, 1, 1));
-    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
+    using BGridDesc    = decltype(MakeBGridDescriptor(1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     // GridwiseGemm
@@ -209,7 +233,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         CDataType,
         InMemoryDataOperationEnum::Set,
         AGridDesc,
-        BGridDesc_K0_N_K1,
+        BGridDesc,
         CGridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
@@ -281,7 +305,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
             b_grid_desc_k0_n_k1_ =
-                DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+                DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
             c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
             block_2_ctile_map_ =
@@ -301,7 +325,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
         AGridDesc a_grid_desc_;
-        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
+        BGridDesc b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -371,7 +395,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     BDataType,
                     CDataType,
                     remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
-                    remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
                     remove_reference_t<
                         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                     AElementwiseOperation,
@@ -404,7 +428,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     BDataType,
                     CDataType,
                     remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
-                    remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
                     remove_reference_t<
                         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                     AElementwiseOperation,
@@ -309,9 +309,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
         auto a_block_buf_switch = a_block_buf;

         // preload data into LDS
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
         a_blockwise_copy.Run(
             a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -364,6 +364,100 @@ struct GridwiseGemmPipeline_v1<1, false, true>

 template <>
 struct GridwiseGemmPipeline_v1<1, true, false>
 {
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+        auto b_block_buf_switch           = b_block_buf;
+
+        // preload data: A tile into LDS, B fragments into registers
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.Run(
+            b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+            do
+            {
+                b_blockwise_copy.Run(
+                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
+                block_sync_lds();
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+                block_sync_lds();
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                b_block_buf = b_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
 };

 template <>
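The new <1, true, false> specialization ping-pongs B between two register buffers: the copy for iteration i+1 lands in b_block_buf_switch while blockwise_gemm consumes b_block_buf, and the two are swapped at the end of each iteration. A plain C++ sketch of that schedule (not CK code; tile contents and names are illustrative):

```cpp
#include <array>
#include <cstdio>
#include <numeric>

constexpr int TileSize = 4;
using Tile = std::array<int, TileSize>;

Tile load_tile(int i) // stands in for b_blockwise_copy.Run(...)
{
    Tile t;
    for(int j = 0; j < TileSize; ++j)
        t[j] = i * TileSize + j;
    return t;
}

int main()
{
    const int num_loop = 3;
    long acc = 0; // stands in for c_thread_buf

    Tile buf = load_tile(0); // preload the first tile
    Tile buf_switch{};

    for(int i = 0; i < num_loop - 1; ++i)
    {
        buf_switch = load_tile(i + 1); // prefetch the next tile
        acc += std::accumulate(buf.begin(), buf.end(), 0L); // "blockwise_gemm"
        buf = buf_switch; // swap: next iteration consumes the prefetched tile
    }
    acc += std::accumulate(buf.begin(), buf.end(), 0L); // tail
    std::printf("acc = %ld\n", acc); // 66
}
```

On the GPU the same structure hides global-memory latency: the fetch for the next B tile is issued before the math on the current tile begins.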
@@ -18,11 +18,11 @@
 namespace ck {

 template <typename GridwiseGemm,
-          typename FloatA,
-          typename FloatB,
-          typename FloatC,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
           typename AGridDesc,
-          typename BGridDesc_K0_N_K1,
+          typename BGridDesc,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -33,11 +33,11 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_wmma(const FloatA* __restrict__ p_a_grid,
-                     const FloatB* __restrict__ p_b_grid,
-                     FloatC* __restrict__ p_c_grid,
+    kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
+                     const BDataType* __restrict__ p_b_grid,
+                     CDataType* __restrict__ p_c_grid,
                      const AGridDesc a_grid_desc,
-                     const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+                     const BGridDesc b_grid_desc,
                      const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                          c_grid_desc_mblock_mperblock_nblock_nperblock,
                      const AElementwiseOperation a_element_op,
@@ -53,7 +53,7 @@ __global__ void
         p_c_grid,
         p_shared,
         a_grid_desc,
-        b_grid_desc_k0_n_k1,
+        b_grid_desc,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
         a_element_op,
         b_element_op,
@@ -64,7 +64,7 @@ __global__ void
     ignore = p_b_grid;
     ignore = p_c_grid;
     ignore = a_grid_desc;
-    ignore = b_grid_desc_k0_n_k1;
+    ignore = b_grid_desc;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
     ignore = b_element_op;
@@ -74,14 +74,14 @@
 }

 template <index_t BlockSize,
-          typename FloatA,
-          typename FloatB,
-          typename FloatAcc,
-          typename FloatCShuffle,
-          typename FloatC,
+          typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CShuffleDataType,
+          typename CDataType,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc,
-          typename BGridDesc_K0_N_K1,
+          typename BGridDesc,
           typename CGridDesc_M_N,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -181,6 +181,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         return a_block_desc;
     }

+    __host__ __device__ static constexpr auto MakeBBlockDescriptor()
+    {
+        constexpr auto b_block_desc = [&]() {
+            if constexpr(BEnableLds)
+            {
+                // K0->N->K1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / K1;
+                constexpr auto max_lds_align = K1;
+
+                if constexpr(BBlockLdsExtraN)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                // KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
+                    make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
+            }
+        }();
+
+        return b_block_desc;
+    }
+
     __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
     {
         constexpr auto a_block_copy_step = [&]() {
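In the BEnableLds == false branch of MakeBBlockDescriptor above, the lengths are (KWmmaPerblock, NRepeat, 1, 1, 1, K1), so each thread privately holds KWmmaPerblock * NRepeat * K1 elements of B in VGPRs. A quick check of that count under assumed sizes (KPerBlock is not shown in this diff; NRepeat = 1 matches the example's new config):

```cpp
#include <cstdio>

int main()
{
    // Assumed tile sizes for illustration only.
    constexpr int KPerBlock = 64, WmmaK = 16, K1 = 8, NRepeat = 1;
    constexpr int KWmmaPerBlock = KPerBlock / WmmaK;

    // Lengths (KWmmaPerBlock, NRepeat, 1, 1, 1, K1) with strides
    // (NRepeat * K1, K1, K1, K1, K1, 1): the unit dims contribute nothing,
    // so the element space collapses to KWmmaPerBlock * NRepeat * K1.
    constexpr int max_offset = (KWmmaPerBlock - 1) * NRepeat * K1 // dim 0
                               + (NRepeat - 1) * K1               // dim 1
                               + (K1 - 1);                        // dim 5
    std::printf("B elements per thread: %d\n", max_offset + 1); // 32
}
```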
@@ -292,43 +326,56 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         return a_wave_desc;
     }

-    template <typename BBlockDesc_BK0_N_BK1>
+    template <typename BBlockDesc_>
     __host__ __device__ static constexpr auto
-    MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
+    MakeBWaveDescriptor(const BBlockDesc_&)
     {
-        constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
-        constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
-        return transform_tensor_descriptor(
-            BBlockDesc_BK0_N_BK1{},
-            make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
-                       make_pass_through_transform(Number<B_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-    }
-
-    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
-    {
-        constexpr auto max_lds_align = K1;
-        constexpr auto K0PerBlock = KPerBlock / K1;
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
+        constexpr auto b_wave_desc = [&]() {
+            if constexpr(BEnableLds)
             {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+                // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
+                constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+                return transform_tensor_descriptor(
+                    BBlockDesc_{},
+                    make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                               make_unmerge_transform(make_tuple(
+                                   Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
             }
             else
             {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                // KWmma_NRepeat_NWave_NRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
+                constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1  = BBlockDesc_{}.GetLength(I5);
+
+                // Workaround, Freeze transform
+                return transform_tensor_descriptor(
+                    BBlockDesc_{},
+                    make_tuple(make_freeze_transform(I0),
+                               make_pass_through_transform(Number<KWmma>{}),
+                               make_pass_through_transform(Number<NRepeat>{}),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<3>{},
+                               Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<4>{},
+                               Sequence<5>{}),
+                    make_tuple(Sequence<>{},
+                               Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<3>{},
+                               Sequence<4>{}));
             }
         }();

-        return b_block_desc_k0perblock_nperblock_k1;
+        return b_wave_desc;
     }

     __host__ __device__ static constexpr auto
@@ -349,7 +396,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc& a_grid_desc,
-                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                  const BGridDesc& b_grid_desc,
                   const CGridDesc_M_N& c_grid_desc_m_n,
                   const Block2CTileMap& block_2_ctile_map)
     {
@@ -378,17 +425,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         const auto GetBProblemsizeNK = [&]() {
             if constexpr(BEnableLds)
             {
-                return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1),
-                                  b_grid_desc_k0_n_k1.GetLength(I0) *
-                                      b_grid_desc_k0_n_k1.GetLength(I2));
+                return make_tuple(b_grid_desc.GetLength(I1),
+                                  b_grid_desc.GetLength(I0) *
+                                      b_grid_desc.GetLength(I2));
             }
             else
             {
                 return make_tuple(
-                    b_grid_desc_k0_n_k1.GetLength(I1) * b_grid_desc_k0_n_k1.GetLength(I2) *
-                        b_grid_desc_k0_n_k1.GetLength(I4),
-                    b_grid_desc_k0_n_k1.GetLength(I0) * b_grid_desc_k0_n_k1.GetLength(I3) *
-                        b_grid_desc_k0_n_k1.GetLength(I5));
+                    b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
+                        b_grid_desc.GetLength(I4),
+                    b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
+                        b_grid_desc.GetLength(I5));
             }
         };
@@ -484,9 +531,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                                        max_lds_align)
                                  : 0;
         static constexpr auto b_block_space_size_aligned =
-            BEnableLds ? math::integer_least_multiple(
-                             GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
-                             max_lds_align)
+            BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
+                                                      max_lds_align)
                        : 0;

         static constexpr auto a_block_space_offset = 0;
@@ -500,18 +546,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         static constexpr auto c_shuffle_block_space_offset = 0;

         static constexpr auto lds_size =
-            math::max(c_shuffle_block_space_size * sizeof(FloatCShuffle),
-                      a_block_space_size_aligned * sizeof(FloatA) +
-                          b_block_space_size_aligned * sizeof(FloatB));
+            math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
+                      a_block_space_size_aligned * sizeof(ADataType) +
+                          b_block_space_size_aligned * sizeof(BDataType));
     };

     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
-    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
-                               const FloatB* __restrict__ p_b_grid,
-                               FloatC* __restrict__ p_c_grid,
+    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
+                               const BDataType* __restrict__ p_b_grid,
+                               CDataType* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
                                const AGridDesc& a_grid_desc,
-                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                               const BGridDesc& b_grid_desc,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                const AElementwiseOperation& a_element_op,
@@ -525,7 +571,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
+            p_b_grid, b_grid_desc.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -554,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         }();

         constexpr auto a_block_desc = MakeABlockDescriptor();
-        constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+        constexpr auto b_block_desc = MakeBBlockDescriptor();

         auto a_block_trait = [&](){
             // A matrix blockwise copy
@@ -562,7 +608,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             {
                 constexpr auto K0PerBlock = KPerBlock / K1;
                 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                    static_cast<FloatA*>(p_shared),
+                    static_cast<ADataType*>(p_shared),
                     SharedMemTrait::a_block_space_size_aligned);

                 auto a_blockwise_copy =
@@ -573,8 +619,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                     /* typename BlockSliceLengths,            */ Sequence<K0PerBlock, MPerBlock, K1>,
                     /* typename ThreadClusterLengths,         */ ABlockTransferThreadClusterLengths_K0_M_K1,
                     /* typename ThreadClusterArrangeOrder,    */ ABlockTransferThreadClusterArrangeOrder,
-                    /* typename SrcData,                      */ FloatA,
-                    /* typename DstData,                      */ FloatA,
+                    /* typename SrcData,                      */ ADataType,
+                    /* typename DstData,                      */ ADataType,
                     /* typename SrcDesc,                      */ decltype(a_grid_desc),
                     /* typename DstDesc,                      */ decltype(a_block_desc),
                     /* typename SrcDimAccessOrder,            */ ABlockTransferSrcAccessOrder,
@@ -601,13 +647,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 // Thread-wise copy
                 // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
                 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                     a_block_desc.GetElementSpaceSize());

                 // Limitation: NumDim of Src and Dst descriptor should be identical
                 auto a_blockwise_copy =
-                    ThreadwiseTensorSliceTransfer_v2<FloatA,
-                                                     FloatA,
+                    ThreadwiseTensorSliceTransfer_v2<ADataType,
+                                                     ADataType,
                                                      decltype(a_grid_desc),
                                                      decltype(a_block_desc),
                                                      Sequence<Number<KWmmaPerBlock>{},
@@ -638,7 +684,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             {
                 constexpr auto K0PerBlock = KPerBlock / K1;
                 auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                    static_cast<FloatB*>(p_shared) + SharedMemTrait::b_block_space_offset,
+                    static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
                     SharedMemTrait::b_block_space_size_aligned);

                 auto b_blockwise_copy =
@@ -649,9 +695,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                         Sequence<K0PerBlock, NPerBlock, K1>,
                         BBlockTransferThreadClusterLengths_K0_N_K1,
                         BBlockTransferThreadClusterArrangeOrder,
-                        FloatB,
-                        FloatB,
-                        decltype(b_grid_desc_k0_n_k1),
+                        BDataType,
+                        BDataType,
+                        decltype(b_grid_desc),
                         decltype(b_block_desc),
                         BBlockTransferSrcAccessOrder,
                         Sequence<0, 1, 2>,
@@ -663,7 +709,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                         1,
                         BThreadTransferSrcResetCoordinateAfterRun,
                         true>(
-                        b_grid_desc_k0_n_k1,
+                        b_grid_desc,
                         make_multi_index(0, n_block_data_idx_on_grid, 0),
                         b_element_op,
                         b_block_desc,
@@ -674,23 +720,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             }
             else
             {
-                constexpr auto K0PerBlock = KPerBlock / K1;
-                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
                     b_block_desc.GetElementSpaceSize());

+                // Limitation: NumDim of Src and Dst descriptor should be identical
                 auto b_blockwise_copy =
-                    ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                     FloatB,
-                                                     decltype(b_grid_desc_k0_n_k1),
+                    ThreadwiseTensorSliceTransfer_v2<BDataType,
+                                                     BDataType,
+                                                     decltype(b_grid_desc),
                                                      decltype(b_block_desc),
-                                                     Sequence<Number<K0PerBlock>{},
+                                                     Sequence<Number<KWmmaPerBlock>{},
                                                               Number<NRepeat>{},
+                                                              I1,
+                                                              I1,
+                                                              I1,
                                                               Number<K1Value>{}>,
-                                                     Sequence<0, 1, 2>,
-                                                     2,
+                                                     Sequence<0, 1, 2, 3, 4, 5>,
+                                                     5,
                                                      BBlockTransferSrcScalarPerVector,
-                                                     1>(
-                        make_multi_index(0, get_thread_local_1d_id() / 32 * 16 + get_thread_local_1d_id() % 16, 0));
+                                                     BThreadTransferSrcResetCoordinateAfterRun,
+                                                     true>(
+                        b_grid_desc,
+                        make_multi_index(0,
                                         n_block_data_idx_on_grid / (NWaves * NPerWmma),
+                                         get_thread_local_1d_id() / 32,
+                                         (get_thread_local_1d_id() % 32) / 16,
+                                         get_thread_local_1d_id() % 16,
+                                         0));

                 return make_tuple(b_block_buf, b_blockwise_copy);
             }
         };
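The start coordinate built above distributes B across the block's lanes: with a wave size of 32, thread id tid maps to wave index tid / 32, K-row (tid % 32) / 16 within the wave, and lane tid % 16 along NPerWmma. A standalone sketch of that mapping (the wave size and the 16-lane split are assumptions mirroring the make_multi_index call):

```cpp
#include <cstdio>

int main()
{
    const int tids[] = {0, 15, 16, 31, 32, 63};
    for(int tid : tids)
    {
        const int n_wave     = tid / 32;        // which wave in the block
        const int k_row      = (tid % 32) / 16; // which 16-lane K row of the wave
        const int n_per_wmma = tid % 16;        // lane within the WMMA tile
        std::printf("tid %2d -> wave %d, row %d, lane %2d\n",
                    tid, n_wave, k_row, n_per_wmma);
    }
}
```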
@@ -706,11 +766,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         auto blockwise_gemm =
             BlockwiseGemmWMMA<BlockSize,
-                              FloatA,
-                              FloatB,
-                              FloatAcc,
+                              ADataType,
+                              BDataType,
+                              AccDataType,
                               decltype(MakeAWaveDescriptor(a_block_desc)),
-                              decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc)),
+                              decltype(MakeBWaveDescriptor(b_block_desc)),
                               MPerBlock,
                               NPerBlock,
                               KPerBlock,
@@ -738,7 +798,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             a_grid_buf,
             a_block_buf,
             a_block_slice_copy_step,
-            b_grid_desc_k0_n_k1,
+            b_grid_desc,
             b_block_desc,
             b_blockwise_copy,
             b_grid_buf,
@@ -768,7 +828,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

         auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatCShuffle*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
+            static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
             SharedMemTrait::c_shuffle_block_space_size);

         constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
@@ -815,8 +875,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         // shuffle: threadwise copy C from VGPR to LDS
         auto c_thread_copy_vgpr_to_lds =
-            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
-                                               FloatCShuffle,
+            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+                                               CShuffleDataType,
                                                decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                ck::tensor_operation::element_wise::PassThrough,
@@ -854,8 +914,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
             CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            FloatCShuffle,        // typename SrcData,
-            FloatC,               // typename DstData,
+            CShuffleDataType,     // typename SrcData,
+            CDataType,            // typename DstData,
             decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
             Sequence<0, 1, 2, 3>, // typename DimAccessOrder,