Commit 6c198a27 authored by Adam Osewski

Formatting & comments.

parent e4bffc94
......@@ -334,6 +334,23 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
......
......@@ -80,7 +80,6 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
//?? what is this for
// sync with the math threads
block_sync_lds();
......@@ -95,8 +94,7 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
// tail
{
block_sync_lds();
// GEMM num_loop
// GEMM num_loop - 1
}
}
};
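For orientation, the block_sync_lds() calls above implement a producer/consumer hand-off between the load waves and the math waves over a shared LDS tile. Below is a minimal host-side C++20 model of that pattern, with std::barrier standing in for block_sync_lds(); all names (lds_tile, load_wave, math_wave, num_loop) are purely illustrative and not CK identifiers.

#include <array>
#include <barrier>
#include <cstdio>
#include <numeric>
#include <thread>

int main()
{
    constexpr int num_loop = 4;
    std::array<int, 8> lds_tile{}; // stand-in for the shared LDS block buffer
    std::barrier<> sync(2);        // stand-in for block_sync_lds()
    long long acc = 0;             // stand-in for the C accumulator

    std::thread load_wave([&] {
        for(int k = 0; k < num_loop; ++k)
        {
            lds_tile.fill(k + 1);   // "global -> LDS" copy for iteration k
            sync.arrive_and_wait(); // tile ready: math wave may consume it
            sync.arrive_and_wait(); // wait until math wave has consumed it
        }
    });

    std::thread math_wave([&] {
        for(int k = 0; k < num_loop; ++k)
        {
            sync.arrive_and_wait(); // wait for the tile of iteration k
            acc = std::accumulate(lds_tile.begin(), lds_tile.end(), acc);
            sync.arrive_and_wait(); // done: load wave may overwrite the tile
        }
    });

    load_wave.join();
    math_wave.join();
    std::printf("acc = %lld\n", acc); // 8 * (1 + 2 + 3 + 4) = 80
}

In the real kernel both sides are wavefronts of the same workgroup, so the synchronization is the workgroup-level LDS barrier rather than a host primitive.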
......
......@@ -173,10 +173,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& /*block_2_etile_map*/)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
......@@ -187,7 +188,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
K == b_grid_desc_n_k.GetLength(I1)))
{
return false;
}
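Stripped of the descriptor machinery, the extended check cross-validates the three problem extents (A is MxK, B is NxK, E is MxN). A standalone sketch under those assumptions follows; it uses plain extents and an illustrative block-divisibility condition, not the CK interface.

#include <cstdio>

struct Extents2D { long d0, d1; }; // {M, K} for A, {N, K} for B, {M, N} for E

bool check_validity(Extents2D a_mk, Extents2D b_nk, Extents2D e_mn,
                    long MPerBlock, long NPerBlock, long KPerBlock)
{
    const long M = a_mk.d0, K = a_mk.d1;
    const long N = b_nk.d0;

    // consistency of the three descriptors: A is MxK, B is NxK, E is MxN
    if(!(M == e_mn.d0 && N == e_mn.d1 && K == b_nk.d1))
        return false;

    // example tiling constraint: every dimension splits evenly into block tiles
    return (M % MPerBlock == 0) && (N % NPerBlock == 0) && (K % KPerBlock == 0);
}

int main()
{
    std::printf("%d\n", check_validity({256, 64}, {128, 64}, {256, 128}, 128, 128, 32)); // 1
    std::printf("%d\n", check_validity({256, 64}, {128, 32}, {256, 128}, 128, 128, 32)); // 0 (K mismatch)
}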
......@@ -258,23 +260,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
const auto M0 = M / M1;
const auto N0 = N / N1;
// FIXME: remove
constexpr auto M01 = I1;
constexpr auto N01 = I1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(make_unmerge_transform(make_tuple(M0, M01)),
make_unmerge_transform(make_tuple(N0, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
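With M01 = N01 = 1 hard-coded above, the two adaptors compose to a plain row-major split of the flat block id over the (M0, N0) tile grid, assuming the usual last-dimension-fastest merge order. A scalar sketch of that mapping (illustrative names, not the CK adaptor API):

#include <cstdio>
#include <utility>

// m0 is the slow (outer) coordinate, n0 the fast (inner) one,
// matching a merge in which the last merged dimension moves fastest.
std::pair<long, long> block_id_to_tile(long block_id, long N0)
{
    return {block_id / N0, block_id % N0};
}

int main()
{
    const long M0 = 4, N0 = 3; // e.g. M / MPerBlock = 4, N / NPerBlock = 3
    for(long bid = 0; bid < M0 * N0; ++bid)
    {
        const auto [m0, n0] = block_id_to_tile(bid, N0);
        std::printf("block %ld -> tile (%ld, %ld)\n", bid, m0, n0);
    }
}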
......@@ -375,7 +373,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
{
// build loadWave and MathWave pipelines
// loadWave and MathWave synchronized through LDS
//
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
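The descriptor above stages the A tile in LDS with shape [AK0, MPerBlock, AK1], i.e. KPerBlock split as AK0 * AK1 with AK1 kept innermost. A simplified addressing sketch of that layout, ignoring any LDS padding the real descriptor may add (sizes are illustrative):

#include <cstdio>

// row-major [AK0][MPerBlock][AK1] offset of one A tile in LDS
long offset_ak0_m_ak1(long ak0, long m, long ak1, long MPerBlock, long AK1)
{
    return (ak0 * MPerBlock + m) * AK1 + ak1;
}

int main()
{
    const long KPerBlock = 32, AK1 = 8, MPerBlock = 128;
    const long AK0 = KPerBlock / AK1; // K split as K = AK0 * AK1

    // element (m, k) of the A tile lands at:
    const long m = 5, k = 19;
    const long off = offset_ak0_m_ak1(k / AK1, m, k % AK1, MPerBlock, AK1);
    std::printf("(m=%ld, k=%ld) -> LDS offset %ld (AK0=%ld)\n", m, k, off, AK0);
}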
......@@ -527,6 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// TODO re-architect LDS+math stages
// Writing data to GMEM: only the math wave does the work in the cshuffle stage
GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
......@@ -594,7 +593,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
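As a rough, hedged model: a thread's C origin is composed of a block tile origin, a wave offset inside the block, and a lane offset inside the wave. The helper below is illustrative only and is not CK's CalculateCThreadOriginDataIndex; all constants are example values.

#include <cstdio>

struct Index2D { long m, n; };

// Illustrative only: not CK's CalculateCThreadOriginDataIndex.
Index2D c_thread_origin(long block_m0, long block_n0, // block tile coordinate
                        long wave_m, long wave_n,     // wave coordinate in the block
                        long lane_m, long lane_n)     // thread coordinate in the wave
{
    const long MPerBlock = 128, NPerBlock = 128; // example block tile sizes
    const long MPerXdl = 32, NPerXdl = 32;       // example xdlops tile sizes
    return {block_m0 * MPerBlock + wave_m * MPerXdl + lane_m,
            block_n0 * NPerBlock + wave_n * NPerXdl + lane_n};
}

int main()
{
    const Index2D o = c_thread_origin(2, 1, 1, 0, 5, 7);
    std::printf("C thread origin = (%ld, %ld)\n", o.m, o.n); // (293, 135)
}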
......@@ -704,14 +703,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
// Different way of getting coalesced writes:
// We can get rid of doing cshuffle. Instead of reading A rows in a contiguous
// manner, read them interleaved; then the mfma can produce a nice c-mat layout,
// as below.
//
// We do not need to do LDS swizzle to align global writes with cache lines:
//   v_mfma cmat, amat, bmat, cmat - c-mat register layout is 1xN elements
//                                   (N is the vertical, i.e. strided, dimension)
//   v_mfma cmat, bmat, amat, cmat - c-mat register layout is Mx1 elements
//                                   (M is the coalescing dimension); by
//                                   enumerating the M index in amat, bmat you
//                                   can align cmat register(s) to contiguous M
//                                   elements, for example:
// 1st mfma instruction output space : 0 4 8 12 16 ....
// 2nd mfma instruction output space : 1 5 9 13 17 ....
// 3rd mfma instruction output space : 2 6 10 14 18 ....
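The enumeration above can be reproduced with a few lines of arithmetic; here is a hedged sketch in which the stride of 4 and the number of interleaved issues are illustrative and not tied to any particular mfma variant.

#include <cstdio>

int main()
{
    const int num_issues = 4; // interleaved mfma issues (illustrative)
    const int elems      = 5; // C elements shown per issue

    for(int issue = 0; issue < num_issues; ++issue)
    {
        std::printf("mfma issue %d output space :", issue + 1);
        for(int e = 0; e < elems; ++e)
            std::printf(" %2d", issue + e * num_issues); // issue, issue+4, issue+8, ...
        std::printf("\n");
    }
    // Taking one element from each issue in turn gives 0, 1, 2, 3, 4, ... -
    // contiguous M indices that coalesce on the way to global memory.
}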
......