Commit 86580888 authored by raman jana
Browse files

fixes for global-write for math-wave

parent a607bc1a
...@@ -11,15 +11,15 @@ struct GridwiseGemmLoadWave; ...@@ -11,15 +11,15 @@ struct GridwiseGemmLoadWave;
template<typename TileLoadThreadGroup> template<typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1> struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{ {
__host__ __device__ static constexpr bool IsSupported(index_t num_loop) __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
{ {
// TODO: improve applicability // TODO: improve applicability
return num_loop % 2 == 0; return true;
} }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{ {
return num_loop / 2 > 1; return num_loop > 1;
} }
template <bool HasMainLoop, template <bool HasMainLoop,
...@@ -37,35 +37,29 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1> ...@@ -37,35 +37,29 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
typename BBlockTransferStep> typename BBlockTransferStep>
static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc, static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc, const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy, ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf, const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf, ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step, const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc, const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy, BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf, const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf, BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
index_t num_loop) index_t num_loop)
{ {
// global read 0 // global read 0
a_block_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
//move to 1 //move to 1
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
//LDS write 0 //LDS write 0
a_block_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// global Read 1 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global Read 1
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
if constexpr(HasMainLoop) if constexpr(HasMainLoop)
{ {
...@@ -75,43 +69,31 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1> ...@@ -75,43 +69,31 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{ {
//sync for Load threads() //sync for Load threads()
block_sync_lds(); block_sync_lds();
// global read i + 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
//?? what is this for //?? what is this for
// sync with math threads() // sync with math threads()
block_sync_lds(); block_sync_lds();
// move to i + 2 //LDS write i+1
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// LDS write i + 1
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
++i; ++i;
} while(i < (num_loop - 2)); } while(i < (num_loop - 1));
} }
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
//what is this for??
block_sync_lds();
// move to i + 2
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
// GEMM num_loop // GEMM num_loop
} }
...@@ -126,15 +108,14 @@ template <typename TileMathThreadGroup> ...@@ -126,15 +108,14 @@ template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1> struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{ {
__host__ __device__ static constexpr bool IsSupported(index_t num_loop) __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
{ {
// TODO: improve applicability return true;
return num_loop % 2 == 0;
} }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{ {
return num_loop / 2 > 1; return num_loop > 1;
} }
template <bool HasMainLoop, template <bool HasMainLoop,
...@@ -165,24 +146,16 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1> ...@@ -165,24 +146,16 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
block_sync_lds(); block_sync_lds();
++i; ++i;
} while(i < (num_loop - 2)); } while(i < (num_loop - 1));
} }
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
// GEMM num_loop - 2
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
block_sync_lds();
// GEMM num_loop - 1 // GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
} }
}; };
......
...@@ -137,10 +137,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -137,10 +137,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
} }
__device__ static constexpr bool IsBelong() __device__ static constexpr bool IsBelong()
{ {
return (get_thread_local_1d_id() < TileLoadThreadGroupSize); return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
} }
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } __device__ static index_t GetThreadId() { return get_thread_local_1d_id() - TileMathThreadGroupSize; }
}; };
...@@ -152,10 +152,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -152,10 +152,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
} }
__device__ static constexpr bool IsBelong() __device__ static constexpr bool IsBelong()
{ {
return get_thread_local_1d_id() >= TileLoadThreadGroupSize; return get_thread_local_1d_id() < TileMathThreadGroupSize;
} }
__device__ static index_t GetThreadId() { return get_thread_local_1d_id() - TileMathThreadGroupSize; } __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
}; };
using CShuffleBlockTransferThreadGroup = using CShuffleBlockTransferThreadGroup =
...@@ -476,11 +476,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -476,11 +476,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
num_k_block_main_loop); num_k_block_main_loop);
block_sync_lds();
block_sync_lds();
} }
else if (TileMathThreadGroup::IsBelong()) else if (TileMathThreadGroup::IsBelong())
{ {
//branch early for math wave //branch early for math wave
constexpr index_t KPack = math::max( constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
...@@ -507,7 +511,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -507,7 +511,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS // a_mtx[K0PerBlock, MPerBlock] is in LDS
...@@ -691,7 +694,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -691,7 +694,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
c_thread_buf, c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf); c_shuffle_block_buf);
// make sure it's safe to read from LDS // make sure it's safe to read from LDS
block_sync_lds(); block_sync_lds();
...@@ -704,12 +706,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -704,12 +706,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); // move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment