Commit 9d8b39a7 authored by Chao Liu's avatar Chao Liu
Browse files

update gridwise gemm to use DynamicBuffer

parent df83690d
......@@ -23,7 +23,11 @@ template <typename GridwiseGemm,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
const FloatA* __restrict__ p_a_global,
const BGlobalDesc b_k_n_global_desc,
const FloatB* __restrict__ p_b_global,
......@@ -55,7 +59,11 @@ template <typename GridwiseGemm,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
const FloatA* __restrict__ p_a_global,
const void __CONSTANT__* p_b_k_n_global_desc,
const FloatB* __restrict__ p_b_global,
......@@ -177,6 +185,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer(p_a_global);
const auto b_global_buf = make_dynamic_buffer(p_b_global);
auto c_global_buf = make_dynamic_buffer(p_c_global);
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
......@@ -353,25 +365,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
auto a_block_even_buf = make_dynamic_buffer(p_a_block_double);
auto b_block_even_buf = make_dynamic_buffer(p_b_block_double);
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_double + a_block_space_size);
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_double + b_block_space_size);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
......@@ -394,16 +400,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
......@@ -417,16 +423,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
......@@ -445,15 +451,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
__syncthreads();
......@@ -502,7 +508,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
}
......
......@@ -84,8 +84,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
auto a_global_buf = make_dynamic_buffer(p_a_global);
auto b_global_buf = make_dynamic_buffer(p_b_global);
const auto a_global_buf = make_dynamic_buffer(p_a_global);
const auto b_global_buf = make_dynamic_buffer(p_b_global);
auto c_global_buf = make_dynamic_buffer(p_c_global);
constexpr auto E = EPerBlock * 3 * 3;
......
......@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 540;
......@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 0
#if 1
using in_data_t = float;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
......@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 0
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment