Commit f1e15d43 authored by wangshaojie6
Browse files

revert gridwise gemm v2r4r2

parent a3b4c5cb
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "merge_transform_for_wrw.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
...@@ -110,9 +109,7 @@ template <index_t BlockSize, ...@@ -110,9 +109,7 @@ template <index_t BlockSize,
index_t CShuffleMRepeatPerShuffle, index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL, index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
bool ABlockLdsExtraM1Wrw = false,
bool BBlockLdsExtraN1Wrw = false>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -124,10 +121,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -124,10 +121,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
// Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = Number<128>{};
static constexpr auto ElePerBank = Number<BankLength / sizeof(FloatAB)>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
...@@ -138,37 +131,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -138,37 +131,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() { constexpr auto a_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM) if constexpr(ABlockLdsExtraM)
{
if constexpr(ABlockLdsExtraM1Wrw)
{
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
make_tuple(
Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
Number<M1PerBlock>{} * K1 + M1Padding,
K1,
I1));
constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
a_block_desc_k0_m0_m1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_block_desc_k0_m_k1_tmp;
}
else
{ {
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1)); make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
} }
}
else else
{ {
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
...@@ -176,101 +145,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -176,101 +145,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
}(); }();
return a_block_desc_k0_m_k1;
}
// Compile-time descriptor of the A tile in LDS (dst of blockwise copy), with a
// leading dimension of length 1 in front of the K0 / M / K1 dimensions.
// The layout is chosen by two compile-time flags:
//   - ABlockLdsExtraM == false : aligned naive layout, alignment K1
//   - ABlockLdsExtraM1Wrw      : 5-d layout padded by M1Padding, with the
//                                (M0PerBlock, M1PerBlock) dims merged into one
//   - otherwise                : explicit strides built from (MPerBlock + 1) * K1
__host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
{
    if constexpr(!ABlockLdsExtraM)
    {
        constexpr auto lds_alignment = K1;

        return make_naive_tensor_descriptor_aligned(
            make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
            lds_alignment);
    }
    else if constexpr(ABlockLdsExtraM1Wrw)
    {
        // Stride of one step along M0: M1PerBlock * K1 elements plus M1Padding.
        constexpr auto m0_stride = Number<M1PerBlock>{} * K1 + M1Padding;

        constexpr auto padded_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
            make_tuple(Number<1>{},
                       Number<K0PerBlock>{},
                       Number<M0PerBlock>{},
                       Number<M1PerBlock>{},
                       K1),
            make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} * m0_stride,
                       Number<M0PerBlock>{} * m0_stride,
                       m0_stride,
                       K1,
                       I1));

        // Collapse (M0, M1) into a single dimension; the other dims pass through.
        return transform_tensor_descriptor(
            padded_desc_b_k0_m0_m1_k1,
            make_tuple(make_pass_through_transform(Number<1>{}),
                       make_pass_through_transform(Number<K0PerBlock>{}),
                       make_merge_transform_v3_division_mod_for_wrw(
                           make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
    }
    else
    {
        // Each K0 slice spans (MPerBlock + 1) * K1 elements, i.e. one extra row of K1.
        return make_naive_tensor_descriptor(
            make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
            make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                       Number<MPerBlock + 1>{} * K1,
                       K1,
                       I1));
    }
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() { constexpr auto b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN) if constexpr(BBlockLdsExtraN)
{
if constexpr(BBlockLdsExtraN1Wrw)
{
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
make_tuple(
Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
Number<N1PerBlock>{} * K1 + N1Padding,
K1,
I1));
constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
b_block_desc_k0_n0_n1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_block_desc_k0_n_k1_tmp;
}
else
{ {
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1)); make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
} }
}
else else
{ {
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
...@@ -278,81 +160,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -278,81 +160,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
}(); }();
return b_block_desc_k0_n_k1;
}
// Compile-time descriptor of the B tile in LDS (dst of blockwise copy), with a
// leading dimension of length 1 in front of the K0 / N / K1 dimensions.
// The layout is chosen by two compile-time flags:
//   - BBlockLdsExtraN == false : aligned naive layout, alignment K1
//   - BBlockLdsExtraN1Wrw      : 5-d layout padded by N1Padding, with the
//                                (N0PerBlock, N1PerBlock) dims merged into one
//   - otherwise                : explicit strides built from (NPerBlock + 1) * K1
__host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
{
    if constexpr(!BBlockLdsExtraN)
    {
        constexpr auto lds_alignment = K1;

        return make_naive_tensor_descriptor_aligned(
            make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
            lds_alignment);
    }
    else if constexpr(BBlockLdsExtraN1Wrw)
    {
        // Stride of one step along N0: N1PerBlock * K1 elements plus N1Padding.
        constexpr auto n0_stride = Number<N1PerBlock>{} * K1 + N1Padding;

        constexpr auto padded_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
            make_tuple(Number<1>{},
                       Number<K0PerBlock>{},
                       Number<N0PerBlock>{},
                       Number<N1PerBlock>{},
                       K1),
            make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} * n0_stride,
                       Number<N0PerBlock>{} * n0_stride,
                       n0_stride,
                       K1,
                       I1));

        // Collapse (N0, N1) into a single dimension; the other dims pass through.
        return transform_tensor_descriptor(
            padded_desc_b_k0_n0_n1_k1,
            make_tuple(make_pass_through_transform(Number<1>{}),
                       make_pass_through_transform(Number<K0PerBlock>{}),
                       make_merge_transform_v3_division_mod_for_wrw(
                           make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
    }
    else
    {
        // Each K0 slice spans (NPerBlock + 1) * K1 elements, i.e. one extra row of K1.
        return make_naive_tensor_descriptor(
            make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
            make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                       Number<NPerBlock + 1>{} * K1,
                       K1,
                       I1));
    }
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = math::integer_least_multiple( constexpr auto a_block_space_size =
a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size = math::integer_least_multiple( constexpr auto b_block_space_size =
b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size = constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
...@@ -497,13 +310,69 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -497,13 +310,69 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto a_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); constexpr auto a_b_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
Number<MPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); constexpr auto b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); constexpr auto b_b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
Number<NPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
max_lds_align);
}
}();
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
...@@ -572,9 +441,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -572,9 +441,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// register // register
// sanity check // sanity check
constexpr index_t KPack =
math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatAB,
...@@ -585,7 +451,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -585,7 +451,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...@@ -609,15 +475,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -609,15 +475,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
} }
// Initialize C // Initialize C
...@@ -627,43 +486,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -627,43 +486,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
index_t k0_block_data_begin = 0; index_t k0_block_data_begin = 0;
block_sync_lds();
//do
//{
// blockwise_gemm.Run();
//
// block_sync_lds();
//
// a_blockwise_copy.MoveSrcSliceWindow();
// b_blockwise_copy.MoveSrcSliceWindow();
//
// a_blockwise_copy.RunWrite();
// b_blockwise_copy.RunWrite();
//
// a_blockwise_copy.RunRead();
// block_sync_lds();
// b_blockwise_copy.RunRead();
//
// k0 += K0PerBlock;
//} while(k0 < (K0 - K0PerBlock));
do do
{ {
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds(); block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock; k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock)); } while(k0_block_data_begin < (K0 - K0PerBlock));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment