Commit 2bab820b authored by wangshaojie6's avatar wangshaojie6
Browse files

add depth for skip lds

parent 0d02519a
#pragma once #ifndef CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_SKIP_B_LDS_V1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
...@@ -111,7 +112,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -111,7 +112,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto MultiK0 = 2 * 1; static constexpr auto BaseMultK0 = 4;
static constexpr auto MultiK0 = BaseMultK0 * 1;
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
...@@ -190,7 +192,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -190,7 +192,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// 2-stage prefetch currently only support even number of K0 loop // 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop // TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0)) if(!((K0 / K0PerBlock) % MultiK0 == 0))
{ {
return false; return false;
} }
...@@ -443,7 +445,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -443,7 +445,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
FloatAB, FloatAB,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(), b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true> true>
b_thread_even_buf, b_thread_odd_buf; b_thread_1st_buf, b_thread_2nd_buf, b_thread_3rd_buf, b_thread_4th_buf;
const auto wave_id = GetWaveIdx(); const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]); const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
...@@ -527,7 +529,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -527,7 +529,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf); b_thread_1st_buf);
// Move // Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
...@@ -539,6 +541,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -539,6 +541,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// a data write to lds // a data write to lds
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// load 2nd a matrix data
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_2nd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
// main body // main body
if constexpr(HasMainK0BlockLoop) if constexpr(HasMainK0BlockLoop)
{ {
...@@ -551,29 +562,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -551,29 +562,53 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
blockwise_gemm.ResetABlockStartWindow(); blockwise_gemm.ResetABlockStartWindow();
block_sync_lds(); block_sync_lds();
static_for<0, MultiK0, 2>{}([&](auto) { static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
// 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf); b_thread_3rd_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
// only move b windows
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
s_nop();
// 2nd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf); b_thread_4th_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); s_nop();
// 3rd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_1st_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
s_nop();
// 4th
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_2nd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
}); });
...@@ -592,35 +627,61 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -592,35 +627,61 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
block_sync_lds(); block_sync_lds();
blockwise_gemm.ResetABlockStartWindow(); blockwise_gemm.ResetABlockStartWindow();
static_for<0, MultiK0, 2>{}([&](auto i) { static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
// block_sync_lds(); // 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf); b_thread_3rd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
if constexpr(i < MultiK0 - 2) // 2nd
{ // only move b windows b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_4th_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 3rd
if constexpr(i < MultiK0 - BaseMultK0)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf); b_thread_1st_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
} }
// block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(); blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
if constexpr(i < MultiK0 - 2) // 4th
if constexpr(i < MultiK0 - BaseMultK0)
{ {
blockwise_gemm.MoveABlockSliceWindow(); b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_2nd_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
} }
blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
}); });
} }
} }
...@@ -706,3 +767,4 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -706,3 +767,4 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
}; };
} // namespace ck } // namespace ck
#endif
...@@ -17,5 +17,11 @@ __device__ void block_sync_lds() ...@@ -17,5 +17,11 @@ __device__ void block_sync_lds()
#endif #endif
} }
__device__ void s_nop(){
asm volatile("\
s_nop 0 \n \
" ::);
}
} // namespace ck } // namespace ck
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment