Commit 6c198a27 authored by Adam Osewski's avatar Adam Osewski
Browse files

Formatting & comments.

parent e4bffc94
...@@ -334,6 +334,23 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -334,6 +334,23 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
......
...@@ -80,7 +80,6 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1> ...@@ -80,7 +80,6 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
//?? what is this for
// sync with math threads() // sync with math threads()
block_sync_lds(); block_sync_lds();
...@@ -95,8 +94,7 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1> ...@@ -95,8 +94,7 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
// GEMM num_loop - 1
// GEMM num_loop
} }
} }
}; };
......
...@@ -173,10 +173,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -173,10 +173,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap> template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, __host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k, const BGridDesc_N_K& b_grid_desc_n_k,
const EGridDesc_M_N& e_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& /*block_2_etile_map*/)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...@@ -187,7 +188,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -187,7 +188,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
const auto K = a_grid_desc_m_k.GetLength(I1); const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K == b_grid_desc_n_k.GetLength(I1)))
{ {
return false; return false;
} }
...@@ -258,23 +260,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -258,23 +260,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
const auto M0 = M / M1; const auto M0 = M / M1;
const auto N0 = N / N1; const auto N0 = N / N1;
// FIXME: remove
constexpr auto M01 = I1; constexpr auto M01 = I1;
constexpr auto N01 = I1; constexpr auto N01 = I1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)), make_tuple(make_unmerge_transform(make_tuple(M0, M01)),
make_unmerge_transform(make_tuple(N00, N01))), make_unmerge_transform(make_tuple(N0, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
...@@ -375,7 +373,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -375,7 +373,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
{ {
// build loadWave and MathWave pipelines // build loadWave and MathWave pipelines
// loadWave and MathWave synchronized through LDS // loadWave and MathWave synchronized through LDS
//
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -527,6 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -527,6 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// TODO re-architect LDS+math stages // TODO re-architect LDS+math stages
// Writing data to GMEM: only math wave is doing the work in cshuffle
GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>( GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop); a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
...@@ -704,14 +703,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -704,14 +703,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
// Different way of getting coalesced writes:
// We can get rid of doing cshuffle. Instead of reading A rows in contiguous manner
// do it interleaved, then mfma can have nice c-mat layout as below:
//
// TODO // TODO
// 1. we do not need to do LDS swizzle to align global writes writing cache // We do not need to do LDS swizzle to align global writes writing cache lines:
// lines
// v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN // v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN
// elments (N is vertical or strided dimension) v_mfma cmat, bmat, amat, // elments (N is vertical or strided
// cmat - c-mat register layout are Mx1 elments (M is coalescing // dimension)
// dimension) by enumerating M index in amat, bmat you can align cmat // v_mfma cmat, bmat, amat, cmat - c-mat register layout are Mx1
// register(s) to contiguous M elements for example // elments (M is coalescing
// dimension) by enumerating M index in
// amat, bmat you can align cmat
// register(s) to contiguous M elements
// for example
// 1st mfma instruction output space : 0 4 8 12 16 .... // 1st mfma instruction output space : 0 4 8 12 16 ....
// 2nd mfma instruction output space : 1 5 9 13 17 .... // 2nd mfma instruction output space : 1 5 9 13 17 ....
// 3rd mfma instruction output space : 2 6 10 14 18 .... // 3rd mfma instruction output space : 2 6 10 14 18 ....
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment