"example/vscode:/vscode.git/clone" did not exist on "0bf57502f260e12ff9ca2cff78424c36c8a0a507"
Commit 9d8b39a7 authored by Chao Liu
Browse files

update gridwise gemm to use DynamicBuffer

parent df83690d
...@@ -23,7 +23,11 @@ template <typename GridwiseGemm, ...@@ -23,7 +23,11 @@ template <typename GridwiseGemm,
typename CBlockClusterDesc, typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc, __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
const FloatA* __restrict__ p_a_global, const FloatA* __restrict__ p_a_global,
const BGlobalDesc b_k_n_global_desc, const BGlobalDesc b_k_n_global_desc,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
...@@ -55,7 +59,11 @@ template <typename GridwiseGemm, ...@@ -55,7 +59,11 @@ template <typename GridwiseGemm,
typename CBlockClusterDesc, typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc, __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
const FloatA* __restrict__ p_a_global, const FloatA* __restrict__ p_a_global,
const void __CONSTANT__* p_b_k_n_global_desc, const void __CONSTANT__* p_b_k_n_global_desc,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
...@@ -177,6 +185,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -177,6 +185,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer(p_a_global);
const auto b_global_buf = make_dynamic_buffer(p_b_global);
auto c_global_buf = make_dynamic_buffer(p_c_global);
const auto K = a_k_m_global_desc.GetLength(I0); const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1); const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1); const auto N = b_k_n_global_desc.GetLength(I1);
...@@ -353,25 +365,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -353,25 +365,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto b_k_n_global_move_slice_window_iterator_hack = constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{}; BGlobalMoveSliceWindowIteratorHacks{};
FloatAB* p_a_block_even = p_a_block_double; auto a_block_even_buf = make_dynamic_buffer(p_a_block_double);
FloatAB* p_b_block_even = p_b_block_double; auto b_block_even_buf = make_dynamic_buffer(p_b_block_double);
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size; auto a_block_odd_buf = make_dynamic_buffer(p_a_block_double + a_block_space_size);
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size; auto b_block_odd_buf = make_dynamic_buffer(p_b_block_double + b_block_space_size);
auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double); a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double); b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
} }
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
...@@ -394,16 +400,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -394,16 +400,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd); a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd); b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
...@@ -417,16 +423,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -417,16 +423,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even); a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even); b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock; k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock); } while(k_block_data_begin < K - 2 * KPerBlock);
...@@ -445,15 +451,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -445,15 +451,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size); a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size); b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
__syncthreads(); __syncthreads();
...@@ -502,7 +508,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -502,7 +508,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_global_desc,
p_c_global, c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks); c_m0_m1_n0_n1_global_tensor_iterator_hacks);
} }
} }
......
...@@ -84,8 +84,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -84,8 +84,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
auto a_global_buf = make_dynamic_buffer(p_a_global); const auto a_global_buf = make_dynamic_buffer(p_a_global);
auto b_global_buf = make_dynamic_buffer(p_b_global); const auto b_global_buf = make_dynamic_buffer(p_b_global);
auto c_global_buf = make_dynamic_buffer(p_c_global); auto c_global_buf = make_dynamic_buffer(p_c_global);
constexpr auto E = EPerBlock * 3 * 3; constexpr auto E = EPerBlock * 3 * 3;
......
...@@ -92,7 +92,7 @@ int main(int argc, char* argv[]) ...@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
...@@ -630,7 +630,7 @@ int main(int argc, char* argv[]) ...@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 0 #if 1
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
...@@ -740,7 +740,7 @@ int main(int argc, char* argv[]) ...@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment