Commit 4841d991 authored by Adam Osewski

Multiple changes to gridwise gemm: mark StrideDs as [[maybe_unused]], add argument validity checks (block-tile divisibility, KBatch * KPerBlock split, vector-access alignment, pipeline loop count), simplify the workspace grid descriptor, rework StorePartials around the blockwise-GEMM C block descriptor, and advance the source slice window while accumulating partials.

parent ad0e4083
......@@ -441,7 +441,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
[[maybe_unused]] const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const index_t KBatch)
{
......@@ -449,10 +449,117 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
ignore = StrideDs;
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(M % MPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(N % NPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = KBatch * KPerBlock;
if(!(K % K_t == 0))
{
#if DEBUG_LOG
std::cout << "Arg K value is not a multiple of ! KBatch * KPerBlock: " << K << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(K % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(M % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(N % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(K % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
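// Aside (illustrative sketch, not part of this commit): the four branches above encode a
// single layout-dependent rule: the contiguous dimension of each source matrix must be
// divisible by its SrcScalarPerVector. Standalone version with hypothetical names:
inline bool VectorAccessIsValid(bool a_row_major, bool b_row_major,
                                long M, long N, long K,
                                long a_scalar_per_vector, long b_scalar_per_vector)
{
    const long a_contig = a_row_major ? K : M; // A loads vectors along K if row-major, else along M
    const long b_contig = b_row_major ? N : K; // B loads vectors along N if row-major, else along K
    return (a_contig % a_scalar_per_vector == 0) && (b_contig % b_scalar_per_vector == 0);
}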
// check gridwise gemm pipeline
const auto num_k_loop = (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
......@@ -461,6 +568,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
#if DEBUG_LOG
std::cout << "The number of k loops (" << num_k_loop
<< ") value is not supported by GridwiseGemm Pipeline."
<< " K0Padded: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
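// Aside (illustrative sketch, not part of this commit): without K padding the split-K
// requirement above reduces to K % (KBatch * KPerBlock) == 0, and the per-split loop
// count passed to GridwiseGemmPipe::IsSupported is then (K / KBatch) / KPerBlock
// (assuming AK0 * AK1 == K / KBatch, which the truncated hunk above suggests).
inline bool SplitKShapeIsValid(long K, long KBatch, long KPerBlock)
{
    return K % (KBatch * KPerBlock) == 0;
}
inline long NumKLoopsPerSplit(long K, long KBatch, long KPerBlock)
{
    return (K / KBatch) / KPerBlock; // each split consumes its K / KBatch slice KPerBlock at a time
}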
......@@ -524,7 +637,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
__host__ __device__ static auto
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(index_t grid_size)
{
const auto w_desc_grid_i1_mperb_nperb = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(
......@@ -537,9 +649,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
}
}();
return w_desc_grid_i1_mperb_nperb;
}
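// Aside (illustrative sketch, not part of this commit): whichever ELayout branch is taken,
// the descriptor above covers one MPerBlock x NPerBlock AccDataType tile per workgroup, so
// the split-K workspace a caller must provide scales with the grid size. Hypothetical
// helper mirroring that (grid_size, 1, MPerBlock, NPerBlock) layout:
inline unsigned long long PartialsWorkspaceSizeBytes(unsigned long long grid_size,
                                                     unsigned long long m_per_block,
                                                     unsigned long long n_per_block,
                                                     unsigned long long acc_elem_bytes)
{
    return grid_size * m_per_block * n_per_block * acc_elem_bytes;
}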
// TODO: we should refactor out all those common Make... descriptors to sth like
......@@ -700,73 +809,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// TODO Need to do CShuffle already here:
__device__ void StorePartials(void* __restrict__ p_workspace)
{
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const auto workspace_grid_desc_m0_n0_m1_n1 =
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1.GetElementSpaceSize());
const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2;
// M0 = grid_size -> MRepeats
// N0 = 1 -> NRepeats
const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1,
make_tuple(make_pass_through_transform(w_grid_m0),
make_pass_through_transform(w_grid_n0),
make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
make_pass_through_transform(M1), // MWave
make_pass_through_transform(N1), // NWave
make_pass_through_transform(M2), // mfma_instr.num_groups_per_blk
make_pass_through_transform(M3), // mfma_instr.num_input_blks
make_pass_through_transform(M4), // mfma_instr.group_size
make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
make_tuple(Sequence<0, 2>{},
Sequence<1, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{},
Sequence<8>{},
Sequence<9>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
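// Aside (illustrative sketch, not part of this commit): per the dimension comments above,
// M0..M4 and N0..N2 decompose the block tile as MXdlPerWave * MWave * num_groups_per_blk *
// num_input_blks * group_size and NXdlPerWave * NWave * num_threads_per_blk. For one
// plausible tuning (assumed values: 32x32 mfma, 256x128 block tile) the products recover
// MPerBlock and NPerBlock exactly:
namespace sketch {
constexpr int M0 = 4, M1 = 2, M2 = 4, M3 = 2, M4 = 4; // MXdlPerWave, MWave, groups, input_blks, group_size
constexpr int N0 = 2, N1 = 2, N2 = 32;                // NXdlPerWave, NWave, threads_per_blk
static_assert(M0 * M1 * M2 * M3 * M4 == 256, "M decomposition recovers MPerBlock");
static_assert(N0 * N1 * N2 == 128, "N decomposition recovers NPerBlock");
} // namespace sketch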
......@@ -810,7 +873,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
1, // DstScalarStrideInVector
true>{// DstResetCoordinateAfterRun
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(static_cast<index_t>(blockIdx.x) * MXdlPerWave,
make_multi_index(m_thread_data_on_block_idx[I0],
n_thread_data_on_block_idx[I0],
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
......@@ -827,14 +890,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
w_grid_buf);
}
__device__ void AccumulatePartials(void* __restrict__ p_workspace, index_t reduce_count)
__device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
{
auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// using CThreadBufferT = ck::remove_reference_t<decltype(c_thread_buf)>;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
......@@ -957,10 +1019,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using Accumulation =
ck::detail::AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
constexpr auto partial_acc_load_step =
make_multi_index(MXdlPerWave, I0, I0, I0, I0, I0, I0, I0);
// We do not need to read this workgroup partial results since they're
// already in c_thread_buff
for(int i_t = 1; i_t < reduce_count; ++i_t)
for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
{
acc_buf.Clear();
acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
......@@ -971,6 +1035,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
[&](auto i_vec) { Accumulation::Calculate(c_thread_buf(i_vec), acc_buf[i_vec]); });
acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
partial_acc_load_step);
}
}
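// Aside (illustrative sketch, not part of this commit): reduced to scalars, the loop above
// keeps this workgroup's own partial in registers (hence i_t starts at 1) and folds in the
// remaining reduce_count - 1 partials read back from the workspace. Hypothetical scalar model:
inline float ReduceSplitKPartials(float own_partial, const float* other_partials, int count)
{
    float acc = own_partial;        // already resident in c_thread_buf in the real kernel
    for(int i = 0; i < count; ++i)  // count corresponds to reduce_count - 1
    {
        acc += other_partials[i];   // Accumulation::Calculate with reduce::Add
    }
    return acc;
}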
......