"client_example/19_pool/max_pool2d_bwd.cpp" did not exist on "76ec0089fb254c221138ec8cb2962ba5056f7fa9"
Commit de87e1dc authored by ltqin's avatar ltqin
Browse files

v2r3r1 add double buffer

parent b560db68
......@@ -23,7 +23,8 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -45,17 +46,18 @@ __global__ void
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
......@@ -67,21 +69,23 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap>
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3r1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_grid_desc_k0_m_k1,
const void CONSTANT* p_b_grid_desc_k0_n_k1,
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_grid_desc_k0_m_k1,
const void CONSTANT* p_b_grid_desc_k0_n_k1,
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -104,17 +108,18 @@ __global__ void
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
GridwiseGemm::template Run<HasMainKBlockLoop, HasDoubleTailKBlockLoop>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#endif
......@@ -219,7 +224,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr auto b_block_space_size =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -281,6 +286,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
return has_main_k0_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
......@@ -371,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
template <bool HasMainKBlockLoop>
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -520,8 +532,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr auto a_block_space_size =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto b_block_space_size =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......@@ -535,18 +550,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_block_desc_k0_n_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_space_size, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_space_size, b_block_desc_k0_n_k1.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_even_buf);
}
// main body
......@@ -558,6 +578,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
{
do
{
// iteration for odd
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
......@@ -565,30 +586,75 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
// gemm even data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// write data into odd buffer
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);
// iteration for even
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
// gemm odd data
blockwise_gemm.Run(a_block_even_buf, b_block_odd_buf, c_thread_buf);
// write data into even buffer
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);
k0_block_data_begin += 2 * K0PerBlock;
} while(k0_block_data_begin < (K0 - 2 * K0PerBlock));
}
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
// iteration for odd
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
block_sync_lds();
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
// gemm even data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// write data into odd buffer
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
block_sync_lds();
// gemm odd data
blockwise_gemm.Run(a_block_even_buf, b_block_odd_buf, c_thread_buf);
}
// tail
else
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment