Commit 4841d991 authored by Adam Osewski

Multiple changes to gridwise gemm.

- Mark the unused StrideDs argument [[maybe_unused]] and drop the ignore = StrideDs; workaround.
- Add argument validity checks (M % MPerBlock, N % NPerBlock, K % (KBatch * KPerBlock), and the
  A/B vector-load widths) with optional DEBUG_LOG diagnostics.
- Return the workspace descriptor directly from the if-constexpr branches in
  MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock.
- Reuse the blockwise C block descriptor as the workspace descriptor in StorePartials and start
  the store at m_thread_data_on_block_idx[I0].
- Take reduce_count as uint32_t and advance the partial-accumulation read window by MXdlPerWave
  after each accumulated tile.

parent ad0e4083
@@ -441,7 +441,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                           const index_t K,
                           const index_t StrideA,
                           const index_t StrideB,
-                          const std::array<index_t, NumDTensor> StrideDs,
+                          [[maybe_unused]] const std::array<index_t, NumDTensor> StrideDs,
                           const index_t StrideE,
                           const index_t KBatch)
     {
@@ -449,10 +449,117 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
         const auto b_grid_desc_kbatch_bk0_n_bk1 =
             MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
 
-        ignore = StrideDs;
-        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
+
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(M % MPerBlock == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(N % NPerBlock == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << N << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            auto K_t = KBatch * KPerBlock;
+            if(!(K % K_t == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K value is not a multiple of KBatch * KPerBlock! K: " << K << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        {
+            if(K % ABlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K (" << K
+                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        else
+        {
+            if(M % ABlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg M (" << M
+                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        {
+            if(N % BBlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg N (" << N
+                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        else
+        {
+            if(K % BBlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K (" << K
+                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
 
         // check gridwise gemm pipeline
         const auto num_k_loop = (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
@@ -461,6 +568,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
+#if DEBUG_LOG
+            std::cout << "The number of k loops (" << num_k_loop
+                      << ") value is not supported by GridwiseGemm Pipeline."
+                      << " K0Padded: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) << __FILE__
+                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
             return false;
         }
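The checks added above all reduce to simple divisibility tests on the problem sizes: tiles must evenly cover M and N unless the corresponding padding specialization is enabled, every split-K slice must cover whole KPerBlock tiles, and the vectorized loads of A and B must divide their contiguous dimension. The following is a minimal standalone sketch of that logic (not part of the commit); the parameter names are illustrative stand-ins for the kernel's template parameters.

#include <cstdint>
#include <iostream>

// Simplified stand-in for the validity checks above: padding specializations skip the
// corresponding divisibility test; vector-load widths must divide the contiguous
// dimension of A/B. All names here are illustrative, not the kernel's API.
bool check_validity(std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t KBatch,
                    std::int64_t MPerBlock, std::int64_t NPerBlock, std::int64_t KPerBlock,
                    std::int64_t AVector, std::int64_t BVector,
                    bool a_row_major, bool b_row_major,
                    bool pad_m, bool pad_n, bool pad_k)
{
    if(!pad_m && M % MPerBlock != 0)
        return false; // M must be a multiple of MPerBlock unless M-padding is enabled
    if(!pad_n && N % NPerBlock != 0)
        return false; // same for N and NPerBlock
    if(!pad_k && K % (KBatch * KPerBlock) != 0)
        return false; // split-K: every k-batch slice must cover whole KPerBlock tiles
    // Row-major A reads vectors along K, column-major A along M (analogously for B).
    if((a_row_major ? K : M) % AVector != 0)
        return false;
    if((b_row_major ? N : K) % BVector != 0)
        return false;
    return true;
}

int main()
{
    std::cout << std::boolalpha
              << check_validity(256, 256, 4096, 4, 128, 128, 32, 8, 8,
                                true, true, false, false, false)
              << '\n'; // prints true: all divisibility constraints hold
}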
@@ -524,7 +637,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     __host__ __device__ static auto
     MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(index_t grid_size)
     {
-        const auto w_desc_grid_i1_mperb_nperb = [&]() {
         if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
         {
             return make_naive_tensor_descriptor(
@@ -537,9 +649,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                 make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
                 make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
         }
-        }();
-
-        return w_desc_grid_i1_mperb_nperb;
     }
 
     // TODO: we should refactor out all those common Make... descriptors to sth like
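The simplification above (dropping the immediately-invoked lambda) relies on a C++17 idiom: with a deduced return type, only the taken if constexpr branch is instantiated, so each branch may return a different descriptor type directly. A small illustration of the idiom, using placeholder types rather than CK descriptors:

// Placeholder descriptor types for illustration only.
struct RowMajorDesc { int stride; };
struct ColMajorDesc { int stride; };

// With a deduced return type, the discarded `if constexpr` branch is not instantiated,
// so the two branches may return different types without wrapping them in a lambda.
template <bool IsRowMajor>
auto make_workspace_desc(int m_per_block, int n_per_block)
{
    if constexpr(IsRowMajor)
    {
        return RowMajorDesc{n_per_block}; // row-major: rows are n_per_block elements apart
    }
    else
    {
        return ColMajorDesc{m_per_block}; // column-major: columns are m_per_block elements apart
    }
}

int main()
{
    auto row_desc = make_workspace_desc<true>(128, 256);  // deduced as RowMajorDesc
    auto col_desc = make_workspace_desc<false>(128, 256); // deduced as ColMajorDesc
    return (row_desc.stride == 256 && col_desc.stride == 128) ? 0 : 1;
}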
@@ -700,73 +809,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     // TODO Need to do CShuffle already here:
     __device__ void StorePartials(void* __restrict__ p_workspace)
     {
-        // M0 = grid_size
-        // N0 = 1
-        // M1 = MPerBlock
-        // N1 = NPerBlock
-        const auto workspace_grid_desc_m0_n0_m1_n1 =
-            MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
-        const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
-        const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
-        auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
-        auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1.GetElementSpaceSize());
         const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
 
-        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
+        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
             BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
-        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
-        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
-        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
-        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
-        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
-        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
-        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
-        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
-        // M0 = grid_size -> MRepeats
-        // N0 = 1 -> NRepeats
-        const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
-            workspace_grid_desc_m0_n0_m1_n1,
-            make_tuple(make_pass_through_transform(w_grid_m0),
-                       make_pass_through_transform(w_grid_n0),
-                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
-                       make_unmerge_transform(make_tuple(N0, N1, N2))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(
-                Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
-        const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-            workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
-            make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
-                       make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
-                       make_pass_through_transform(M1),                 // MWave
-                       make_pass_through_transform(N1),                 // NWave
-                       make_pass_through_transform(M2),  // mfma_instr.num_groups_per_blk
-                       make_pass_through_transform(M3),  // mfma_instr.num_input_blks
-                       make_pass_through_transform(M4),  // mfma_instr.group_size
-                       make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
-            make_tuple(Sequence<0, 2>{},
-                       Sequence<1, 3>{},
-                       Sequence<4>{},
-                       Sequence<5>{},
-                       Sequence<6>{},
-                       Sequence<7>{},
-                       Sequence<8>{},
-                       Sequence<9>{}),
-            make_tuple(Sequence<0>{},
-                       Sequence<1>{},
-                       Sequence<2>{},
-                       Sequence<3>{},
-                       Sequence<4>{},
-                       Sequence<5>{},
-                       Sequence<6>{},
-                       Sequence<7>{}));
+        constexpr auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+
+        auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
+        auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
+
+        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
+        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
+        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
+        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
+        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
+        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
+        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
+        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
 
         constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
             BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -810,7 +873,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             1, // DstScalarStrideInVector
             true>{// DstResetCoordinateAfterRun
                   workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                  make_multi_index(static_cast<index_t>(blockIdx.x) * MXdlPerWave,
+                  make_multi_index(m_thread_data_on_block_idx[I0],
                                    n_thread_data_on_block_idx[I0],
                                    m_thread_data_on_block_idx[I1],
                                    n_thread_data_on_block_idx[I1],
@@ -827,14 +890,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                 w_grid_buf);
     }
 
-    __device__ void AccumulatePartials(void* __restrict__ p_workspace, index_t reduce_count)
+    __device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
     {
         auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
 
         constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
             BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-        // using CThreadBufferT = ck::remove_reference_t<decltype(c_thread_buf)>;
         StaticBuffer<AddressSpaceEnum::Vgpr,
                      AccDataType,
                      c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
@@ -957,10 +1019,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         using Accumulation =
            ck::detail::AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
 
+        constexpr auto partial_acc_load_step =
+            make_multi_index(MXdlPerWave, I0, I0, I0, I0, I0, I0, I0);
+
         // We do not need to read this workgroup partial results since they're
         // already in c_thread_buff
-        for(int i_t = 1; i_t < reduce_count; ++i_t)
+        for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
         {
             acc_buf.Clear();
             acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
@@ -971,6 +1035,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
                 [&](auto i_vec) { Accumulation::Calculate(c_thread_buf(i_vec), acc_buf[i_vec]); });
+
+            acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                        partial_acc_load_step);
         }
     }
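For orientation, the loop touched above implements a split-K reduction: each workgroup keeps its own partial tile in registers and folds in the remaining reduce_count - 1 partial tiles from the workspace, moving the read window by a fixed step after each tile. Below is a minimal host-side sketch of the same pattern (not the CK device code); the names c_tile, workspace, and tile_elems are illustrative assumptions.

#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side sketch of the split-K accumulation pattern: this workgroup's partials are
// already in c_tile, so accumulation starts at i_t = 1, and the read offset advances by
// one tile per step, analogous to MoveSrcSliceWindow(..., partial_acc_load_step).
void accumulate_partials(float* c_tile,              // this workgroup's partial results
                         const float* workspace,     // partial tiles of the other splits
                         std::uint32_t reduce_count, // number of k-batch splits
                         std::size_t tile_elems)     // elements per partial tile
{
    std::size_t read_offset = 0;
    for(std::uint32_t i_t = 1; i_t < reduce_count; ++i_t)
    {
        for(std::size_t i = 0; i < tile_elems; ++i)
            c_tile[i] += workspace[read_offset + i];
        read_offset += tile_elems; // move the source window to the next partial tile
    }
}

int main()
{
    constexpr std::size_t tile_elems = 4;
    std::vector<float> c_tile(tile_elems, 1.0f);        // this workgroup's partials
    std::vector<float> workspace(3 * tile_elems, 2.0f); // partials of 3 other splits
    accumulate_partials(c_tile.data(), workspace.data(), 4, tile_elems);
    return c_tile[0] == 7.0f ? 0 : 1; // 1 + 3 * 2
}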