"examples/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "c44e985dc20ec79dcf4e64a9c1f6b8fa395d853b"
Commit 89123dd7 authored by Chao Liu
Browse files

refactor

parent 76df7392
......@@ -68,13 +68,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
......@@ -209,14 +209,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
......@@ -230,33 +230,42 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto a_block_slice_copy_step = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_step = Sequence<KPerBlock, 0>{};
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
__syncthreads();
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
......@@ -265,12 +274,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
}
// LDS double buffer: tail
......@@ -282,8 +290,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
......@@ -296,15 +304,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if has 1 iteration left
{
......@@ -433,13 +442,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
......@@ -584,14 +593,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
......@@ -603,15 +612,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
{
for(index_t k1 = 0; k1 < K1; ++k1)
{
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
......@@ -623,20 +631,20 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
......@@ -662,8 +670,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
......@@ -675,16 +683,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
a_blockwise_copy.RunStoreThreadBuffer(
p_a_thread_buffer, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(
p_b_thread_buffer, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space,
p_b_block_double + b_block_space,
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if has 1 iteration left
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment