Commit 04c6a978 authored by aska-0096

Skip B-Lds real gemm

parent f00dab9f
@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         8,           // K1
         16,          // MPerWmma
         16,          // NPerWmma
-        2,           // M Repeat
-        4,           // N-Repeat
+        8,           // M Repeat
+        1,           // N-Repeat
         S<4, 64, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         true,
         1,           // C shuffle (M Repeat) Per store
         1,           // C shuffle (N Repeat) Per store
-        S<1, 64, 1, 4>,
+        S<1, 16, 1, 16>,
         8>;
 // clang-format on
...
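Note (illustrative, not part of the commit): the retuned instance above redistributes rather than changes the per-block work, assuming the wave layout stays the same: the repeat product is preserved (2 * 4 = 8 * 1), and both C-shuffle thread clusters enumerate 256 threads (1*64*1*4 = 1*16*1*16), matching the 4*64*1 block-transfer cluster visible in the same hunk. The arithmetic as a compile-time check in plain C++:

// Sanity-check sketch for the retuned parameters (values copied from the hunk above).
constexpr int OldMRepeat = 2, OldNRepeat = 4;
constexpr int NewMRepeat = 8, NewNRepeat = 1;

// The M/N repeat product is preserved, so the per-block tile area is unchanged
// as long as the wave counts (not shown in this hunk) stay the same.
static_assert(OldMRepeat * OldNRepeat == NewMRepeat * NewNRepeat, "repeat product changed");

// Both C-shuffle thread clusters cover the same number of threads; only the
// split between the M-side and N-side dimensions changes (64 x 4 -> 16 x 16).
static_assert(1 * 64 * 1 * 4 == 1 * 16 * 1 * 16, "C-shuffle cluster size changed");
static_assert(1 * 16 * 1 * 16 == 4 * 64 * 1, "cluster does not match the block-transfer cluster");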
@@ -106,12 +106,13 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
             }
-#ifdef ENABLE_COLMAJOR
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+                const auto a_grid_desc_mraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
+
+                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
             }
-#endif
         }();
 
         const auto M = a_grid_desc_m_k.GetLength(I0);
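With this hunk the column-major A path is no longer gated behind ENABLE_COLMAJOR and goes through matrix_padder just like the row-major path. As a rough illustration of what that padding amounts to (an assumption about PadADescriptor_M_K's behaviour, which also depends on the GemmSpecialization; none of this code is in the commit), M and K are rounded up to the block tile:

// Illustrative sketch only: assume the padder right-pads M and K up to multiples
// of the block tile when padding is enabled by the GemmSpecialization.
constexpr long RoundUpToTile(long x, long tile) { return (x + tile - 1) / tile * tile; }

static_assert(RoundUpToTile(1000, 128) == 1024, "M = 1000 padded to 1024 for MPerBlock = 128");
static_assert(RoundUpToTile(384, 64) == 384, "an already-aligned extent is left unchanged");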
@@ -146,34 +147,57 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         }
     }
 
-    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
+    static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
     {
-        const auto b_grid_desc_nraw_kraw = [&]() {
+        const auto b_grid_desc_n_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
+                const auto b_grid_desc_nraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                 make_tuple(I1, StrideB));
+
+                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
+                const auto b_grid_desc_nraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                 make_tuple(StrideB, I1));
+
+                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
             }
         }();
 
-        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
-
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = b_grid_desc_n_k.GetLength(I1);
 
         assert(K % K1 == 0);
 
-        const index_t K0 = K / K1;
-
-        return transform_tensor_descriptor(
-            b_grid_desc_n_k,
-            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                       make_pass_through_transform(N)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        if constexpr(BEnableLds)
+        {
+            const index_t K0 = K / K1;
+
+            return transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+        else
+        {
+            constexpr auto B_KRow = WmmaK / K1;
+            const auto B_KWmma    = K / WmmaK;
+            const auto N0         = N / NPerBlock;
+
+            return transform_tensor_descriptor(
+                b_grid_desc_n_k,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
+                    make_unmerge_transform(
+                        make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+        }
     }
 
     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
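For the new BEnableLds == false branch, the B grid descriptor unmerges K into (B_KWmma, B_KRow, K1) and N into (N0 * NRepeat, NWaves, NPerWmma), giving the 6-d ordering (B_KWmma, N0*NRepeat, NWaves, B_KRow, NPerWmma, K1) per the output Sequences. The host-side sketch below reproduces that index arithmetic with plain integer math, mirroring the usual unmerge convention (right-most length varies fastest). K1 = 8, NPerWmma = 16 and NRepeat = 1 match the instance in the first hunk; WmmaK = 16 and NWaves = 2 are example assumptions, and NPerBlock is derived from them. This is not code from the commit.

// Host-side sketch of the unmerge arithmetic behind the BEnableLds == false descriptor.
#include <cassert>
#include <cstdio>

int main()
{
    constexpr int K1 = 8, WmmaK = 16, NPerWmma = 16, NWaves = 2, NRepeat = 1;
    constexpr int NPerBlock = NRepeat * NWaves * NPerWmma; // 32 (assumed relation)
    constexpr int B_KRow    = WmmaK / K1;                  // 2

    const int K = 64, N = 128;
    const int N0 = N / NPerBlock; // 4 block tiles along N

    // Decompose a flat (k, n) coordinate the same way the two unmerge transforms do.
    const int k = 37, n = 101;
    const int k1     = k % K1;
    const int k_row  = (k / K1) % B_KRow;
    const int k_wmma = k / WmmaK;
    const int n_wmma = n % NPerWmma;
    const int n_wave = (n / NPerWmma) % NWaves;
    const int n_rep  = n / (NPerWmma * NWaves); // indexes the N0 * NRepeat dimension

    // Re-merging must give back the original flat coordinates.
    assert((k_wmma * B_KRow + k_row) * K1 + k1 == k);
    assert((n_rep * NWaves + n_wave) * NPerWmma + n_wmma == n);

    std::printf("(k=%d, n=%d) -> (KWmma=%d, N0xNRepeat=%d, NWaves=%d, KRow=%d, NPerWmma=%d, K1=%d)\n",
                k, n, k_wmma, n_rep, n_wave, k_row, n_wmma, k1);
    (void)N0;
    return 0;
}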
@@ -196,7 +220,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     // Gridwise descriptor, mapping to whole given problem.
     using AGridDesc    = decltype(MakeAGridDescriptor(1, 1, 1));
-    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
+    using BGridDesc    = decltype(MakeBGridDescriptor(1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
 
     // GridwiseGemm
@@ -209,7 +233,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         CDataType,
         InMemoryDataOperationEnum::Set,
         AGridDesc,
-        BGridDesc_K0_N_K1,
+        BGridDesc,
         CGridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
@@ -281,7 +305,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
             b_grid_desc_k0_n_k1_ =
-                DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+                DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
             c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
 
             block_2_ctile_map_ =
@@ -301,7 +325,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
         AGridDesc a_grid_desc_;
-        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
+        BGridDesc b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -371,7 +395,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                 BDataType,
                 CDataType,
                 remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
-                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
                 remove_reference_t<
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                 AElementwiseOperation,
@@ -404,7 +428,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                 BDataType,
                 CDataType,
                 remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
-                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
                 remove_reference_t<
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                 AElementwiseOperation,
...
@@ -309,9 +309,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
         auto a_block_buf_switch = a_block_buf;
 
         // preload data into LDS
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
         a_blockwise_copy.Run(
             a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -364,6 +364,100 @@ struct GridwiseGemmPipeline_v1<1, false, true>
 template <>
 struct GridwiseGemmPipeline_v1<1, true, false>
 {
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+        auto b_block_buf_switch           = b_block_buf;
+
+        // preload data into LDS
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.Run(
+            b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                b_blockwise_copy.Run(
+                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
+
+                block_sync_lds();
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+                block_sync_lds();
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                b_block_buf = b_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
 };
 
 template <>
...
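The added body fills in GridwiseGemmPipeline_v1<1, true, false> (presumably <NumPrefetch, AEnableLds, BEnableLds>, mirroring the <1, false, true> specialization above it): A is still staged through LDS, while B is read straight from global memory and ping-ponged between two register-resident block buffers (b_block_buf / b_block_buf_switch). The toy model below only illustrates that schedule, with integers standing in for whole tiles; it is not part of the commit.

// Toy model of the register-double-buffered schedule: asserts check that every
// "gemm" step consumes matching A and B tiles while the next tiles are prefetched.
#include <cassert>

int main()
{
    const int num_loop = 5;

    int a_lds_buf      = -1; // tile currently resident in (simulated) LDS
    int a_reg_stage    = -1; // tile held in the A copy's registers (RunRead result)
    int b_block_buf    = -1; // B tile the gemm reads
    int b_block_switch = -1; // B tile being prefetched into the spare buffer

    int a_next = 0, b_next = 0; // next global tile index each copy will fetch

    // preload
    a_reg_stage = a_next++;        // a_blockwise_copy.RunRead
    b_block_buf = b_next++;        // b_blockwise_copy.Run (global -> registers)
    a_lds_buf   = a_reg_stage;     // a_blockwise_copy.RunWrite

    int computed = 0;

    // main body
    for(int i = 0; i < num_loop - 1; ++i)
    {
        b_block_switch = b_next++; // prefetch next B tile into the spare buffer
        a_reg_stage    = a_next++; // prefetch next A tile into registers

        assert(a_lds_buf == computed && b_block_buf == computed); // blockwise_gemm.Run
        ++computed;

        a_lds_buf   = a_reg_stage;    // a_blockwise_copy.RunWrite (after LDS sync)
        b_block_buf = b_block_switch; // switch the B double buffer
    }

    // tail
    assert(a_lds_buf == computed && b_block_buf == computed);
    ++computed;

    assert(computed == num_loop);
    return 0;
}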