Commit 89123dd7 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 76df7392
...@@ -68,13 +68,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -68,13 +68,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{}); Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space = constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space = constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
} }
__device__ void Run(const Float* __restrict__ p_a_global, __device__ void Run(const Float* __restrict__ p_a_global,
...@@ -209,14 +209,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -209,14 +209,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
ThreadGemmBThreadCopySrcDataPerRead_N>{}; ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space = constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space = constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block; Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space; Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output // register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
...@@ -230,47 +230,55 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -230,47 +230,55 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
b_blockwise_copy.Run(p_b_global, p_b_block_double); b_blockwise_copy.Run(p_b_global, p_b_block_double);
} }
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{}; constexpr auto a_block_slice_copy_step = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{}; constexpr auto b_block_slice_copy_step = Sequence<KPerBlock, 0>{};
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size;
// LDS double buffer: main body // LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock) k_block_data_begin += 2 * KPerBlock)
{ {
#pragma unroll Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
for(index_t iloop = 0; iloop < 2; ++iloop) Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now = // even iteration
even_loop ? p_a_block_double : p_a_block_double + a_block_space; a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
Float* p_b_block_now = b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next = __syncthreads();
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; // LDS doubel buffer: load next data from device mem
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); // LDS double buffer: GEMM on current data
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
__syncthreads(); // LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);
// LDS doubel buffer: load next data from device mem // odd iteration
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
// LDS double buffer: GEMM on current data __syncthreads();
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS // LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next); a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next); b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
}
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
} }
// LDS double buffer: tail // LDS double buffer: tail
...@@ -282,8 +290,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -282,8 +290,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads(); __syncthreads();
...@@ -296,15 +304,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -296,15 +304,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space); p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space); p_b_block_double + b_block_space_size);
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run( blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread); p_b_block_double + b_block_space_size,
p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
{ {
...@@ -433,13 +442,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 ...@@ -433,13 +442,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{}); Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space = constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space = constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
} }
__device__ void Run(const Float* __restrict__ p_a_global, __device__ void Run(const Float* __restrict__ p_a_global,
...@@ -584,14 +593,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 ...@@ -584,14 +593,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
ThreadGemmBThreadCopySrcDataPerRead_N>{}; ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space = constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space = constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align); math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block; Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space; Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output // register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
...@@ -603,15 +612,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 ...@@ -603,15 +612,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
{ {
for(index_t k1 = 0; k1 < K1; ++k1) for(index_t k1 = 0; k1 < K1; ++k1)
{ {
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.Run(p_a_global, p_a_block_double); a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double); b_blockwise_copy.Run(p_b_global, p_b_block_double);
} }
constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{}; constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{}; constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
// LDS double buffer: main body // LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
...@@ -623,20 +631,20 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 ...@@ -623,20 +631,20 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
const bool even_loop = (iloop % 2 == 0); const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now = Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space; even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now = Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space; even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next = Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double; even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next = Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double; even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads(); __syncthreads();
...@@ -662,8 +670,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 ...@@ -662,8 +670,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads(); __syncthreads();
...@@ -675,16 +683,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 ...@@ -675,16 +683,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, a_blockwise_copy.RunStoreThreadBuffer(
p_a_block_double + a_block_space); p_a_thread_buffer, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, b_blockwise_copy.RunStoreThreadBuffer(
p_b_block_double + b_block_space); p_b_thread_buffer, p_b_block_double + b_block_space_size);
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space, blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space, p_b_block_double + b_block_space_size,
p_c_thread); p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment