Commit de87e1dc authored by ltqin's avatar ltqin
Browse files

v2r3r1 add double buffer

parent b560db68
...@@ -23,7 +23,8 @@ template <typename GridwiseGemm, ...@@ -23,7 +23,8 @@ template <typename GridwiseGemm,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -45,17 +46,18 @@ __global__ void ...@@ -45,17 +46,18 @@ __global__ void
__shared__ FloatAB p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
p_b_grid, p_a_grid,
p_c_grid, p_b_grid,
p_shared_block, p_c_grid,
a_grid_desc_k0_m_k1, p_shared_block,
b_grid_desc_k0_n_k1, a_grid_desc_k0_m_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, b_grid_desc_k0_n_k1,
a_element_op, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
b_element_op, a_element_op,
c_element_op, b_element_op,
block_2_ctile_map); c_element_op,
block_2_ctile_map);
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -67,21 +69,23 @@ template <typename GridwiseGemm, ...@@ -67,21 +69,23 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename Block2CTileMap> typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_v2r3r1(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_v2r3r1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_grid_desc_k0_m_k1, const void CONSTANT* p_a_grid_desc_k0_m_k1,
const void CONSTANT* p_b_grid_desc_k0_n_k1, const void CONSTANT* p_b_grid_desc_k0_n_k1,
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_a_element_op, const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op, const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op, const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map) const void CONSTANT* p_block_2_ctile_map)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -104,17 +108,18 @@ __global__ void ...@@ -104,17 +108,18 @@ __global__ void
__shared__ FloatAB p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
p_b_grid, p_a_grid,
p_c_grid, p_b_grid,
p_shared_block, p_c_grid,
a_grid_desc_k0_m_k1, p_shared_block,
b_grid_desc_k0_n_k1, a_grid_desc_k0_m_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, b_grid_desc_k0_n_k1,
a_element_op, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
b_element_op, a_element_op,
c_element_op, b_element_op,
block_2_ctile_map); c_element_op,
block_2_ctile_map);
} }
#endif #endif
...@@ -219,7 +224,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1 ...@@ -219,7 +224,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr auto b_block_space_size = constexpr auto b_block_space_size =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...@@ -281,6 +286,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1 ...@@ -281,6 +286,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
return has_main_k0_block_loop; return has_main_k0_block_loop;
} }
// Returns true when the quotient K0 / K0PerBlock is an even number, false otherwise.
//
// K0: extent that is divided by K0PerBlock to obtain the step count.
//
// NOTE(review): presumably the result selects the HasDoubleTailKBlockLoop
// template argument (two tail iterations instead of one) — confirm at the call site.
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
    const auto num_k_block_steps = K0 / K0PerBlock;

    return (num_k_block_steps % 2) == 0;
}
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
...@@ -371,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1 ...@@ -371,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_grid, Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
...@@ -520,8 +532,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1 ...@@ -520,8 +532,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block; constexpr auto b_block_space_size =
FloatAB* p_b_block = p_shared_block + a_block_space_size; math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
...@@ -535,18 +550,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1 ...@@ -535,18 +550,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); p_a_block_double, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); p_b_block_double, b_block_desc_k0_n_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_space_size, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_space_size, b_block_desc_k0_n_k1.GetElementSpaceSize());
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_even_buf);
} }
// main body // main body
...@@ -558,6 +578,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1 ...@@ -558,6 +578,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
{ {
do do
{ {
// iteration for odd
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack); a_k0_m_k1_grid_move_slice_window_step_hack);
...@@ -565,30 +586,75 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1 ...@@ -565,30 +586,75 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
b_block_slice_copy_step, b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack); b_k0_n_k1_grid_move_slice_window_step_hack);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
// gemm even data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// write data into odd buffer
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);
// iteration for even
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
block_sync_lds(); block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
// gemm odd data
blockwise_gemm.Run(a_block_even_buf, b_block_odd_buf, c_thread_buf);
// write data into even buffer
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);
k0_block_data_begin += 2 * K0PerBlock;
} while(k0_block_data_begin < (K0 - 2 * K0PerBlock));
}
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); // tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
// iteration for odd
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
block_sync_lds(); block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); // LDS double buffer: load last data from device mem
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
// gemm even data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// write data into odd buffer
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);
k0_block_data_begin += K0PerBlock; block_sync_lds();
} while(k0_block_data_begin < (K0 - K0PerBlock)); // gemm odd data
blockwise_gemm.Run(a_block_even_buf, b_block_odd_buf, c_thread_buf);
} }
else
// tail
{ {
block_sync_lds(); block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
} }
// output: register to global memory // output: register to global memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment