"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "f40942ad630c7e261c3cff5aa06be20e7e486d5b"
Commit 28325204 authored by Chao Liu

Adding fp16 direct convolution that reads pre-vectorized data

parent 4f0fc72e
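
In the diff below, the kernel now takes its input and weight tensors as pointers to vector_type<Float, ScalarPerVector>::VectorType, i.e. the global data is assumed to already be packed into groups of ScalarPerVector scalars (for fp16, e.g. four half values per vector), and the LDS buffers and blockwise copies are typed on that vector type as well. The following minimal sketch only illustrates the idea; vector_type, half_t, and copy_vectorized here are hypothetical stand-ins, not the repository's actual definitions.

// pre_vectorized_read_sketch.cpp -- illustrative only, not repository code
#include <cstdint>
#include <cstdio>

using half_t = std::uint16_t; // stand-in for a 16-bit float payload

// Map a scalar type plus a lane count to a packed "vector" type.
template <class T, unsigned N>
struct vector_type
{
    struct alignas(sizeof(T) * N) VectorType
    {
        T data[N];
    };
};

// A kernel-style copy that receives *pre-vectorized* data: each pointer
// dereference moves sizeof(T) * N bytes instead of one scalar.
template <class Float, unsigned ScalarPerVector>
void copy_vectorized(
    const typename vector_type<Float, ScalarPerVector>::VectorType* __restrict__ p_in_vec,
    typename vector_type<Float, ScalarPerVector>::VectorType* __restrict__ p_out_vec,
    unsigned num_vectors)
{
    for(unsigned i = 0; i < num_vectors; ++i)
        p_out_vec[i] = p_in_vec[i]; // one assignment moves a whole vector
}

int main()
{
    constexpr unsigned ScalarPerVector = 4;
    using vec_t = vector_type<half_t, ScalarPerVector>::VectorType;

    vec_t in[2]  = {{{1, 2, 3, 4}}, {{5, 6, 7, 8}}};
    vec_t out[2] = {};

    copy_vectorized<half_t, ScalarPerVector>(in, out, 2);
    std::printf("out[1].data[2] = %u\n", (unsigned)out[1].data[2]); // prints 7
}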
@@ -27,8 +27,8 @@ template <class Float,
           unsigned BlockSize,
           unsigned GridSize>
 __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
-    const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_in_global,
-    const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_wei_global,
+    const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_in_vec_global,
+    const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_wei_vec_global,
     Float* const __restrict__ p_out_global)
 {
     using scalar_t = Float;
@@ -76,25 +76,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
             ? InBlockCopyDataPerRead
             : WeiBlockCopyDataPerRead;

-    __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-    __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
+    __shared__ vector_t p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
+    __shared__ vector_t p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];

     // threadwise tensors
     constexpr unsigned HiPerThread = HoPerThread + Y - 1;
     constexpr unsigned WiPerThread = WoPerThread + X - 1;

-    constexpr auto in_nchw_thread_block_desc =
+    constexpr auto in_nchw_vec_thread_block_desc =
         make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
                                       in_nchw_vec_block_desc.GetStrides());

-    constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
+    constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor(
         Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());

     constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
-        in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
+        in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);

     // register
-    Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
+    scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

     // divide block work
     constexpr unsigned NBlockWork =
@@ -150,7 +150,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(

     constexpr auto blockwise_in_copy =
         Blockwise4dTensorCopy1<BlockSize,
-                               Float,
+                               vector_t,
                                decltype(in_nchw_vec_global_desc),
                                decltype(in_nchw_vec_block_desc),
                                decltype(in_nchw_vec_block_desc.GetLengths()),
@@ -159,7 +159,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #if 0
     constexpr auto blockwise_wei_copy =
         Blockwise4dTensorCopy1<BlockSize,
-                               Float,
+                               vector_t,
                                decltype(wei_kcyx_vec_global_desc),
                                decltype(wei_kcyx_vec_block_desc),
                                decltype(wei_kcyx_vec_block_desc.GetLengths()),
@@ -167,7 +167,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #elif 1
     const auto blockwise_wei_copy =
         Blockwise2dTensorCopy3<BlockSize,
-                               Float,
+                               vector_t,
                                decltype(wei_ke_vec_global_desc),
                                decltype(wei_ke_vec_block_desc),
                                decltype(wei_ke_vec_block_desc.GetLengths()),
@@ -181,16 +181,16 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         c_block_data_begin += CPerBlock, __syncthreads())
     {
         // copy input tensor to LDS
-        blockwise_in_copy.Run(p_in_global + in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin,
+        blockwise_in_copy.Run(p_in_vec_global + in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin,
                                                                                 c_block_data_begin,
                                                                                 hi_block_data_begin,
                                                                                 wi_block_data_begin),
-                              p_in_block);
+                              p_in_vec_block);

         // copy weight tensor to LDS
-        blockwise_wei_copy.Run(p_wei_global + wei_kcyx_vec_global_desc.Get1dIndex(
+        blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex(
                                    k_block_data_begin, c_block_data_begin, 0, 0),
-                               p_wei_block);
+                               p_wei_vec_block);

         __syncthreads();

@@ -199,25 +199,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         // threadwise convolution
 #if 1
         threadwise_direct_convolution_2(
-            in_nchw_thread_block_desc,
-            p_in_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
+            in_nchw_vec_thread_block_desc,
+            p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                            c_thread_data,
                                                            hi_thread_data_begin,
                                                            wi_thread_data_begin),
-            wei_kcyx_thread_block_desc,
-            p_wei_block +
+            wei_kcyx_vec_thread_block_desc,
+            p_wei_vec_block +
                 wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
             out_nkhw_thread_desc,
             p_out_thread);
 #elif 0
         threadwise_direct_convolution_3(
-            in_nchw_thread_block_desc,
-            p_in_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
+            in_nchw_vec_thread_block_desc,
+            p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                            c_thread_data,
                                                            hi_thread_data_begin,
                                                            wi_thread_data_begin),
-            wei_kcyx_thread_block_desc,
-            p_wei_block +
+            wei_kcyx_vec_thread_block_desc,
+            p_wei_vec_block +
                 wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
             out_nkhw_thread_desc,
             p_out_thread);
......
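
One consequence of the pointer-type change worth keeping in mind: every offset returned by Get1dIndex is now applied to a vector-typed pointer, so block- and thread-level offsets are counted in vectors rather than scalars. A tiny standalone illustration of that pointer arithmetic (Vec4 and the constant offset 3 are stand-ins for the real vector type and a descriptor-computed index, not repository code):

// vector_offset_sketch.cpp -- illustrative only, not repository code
#include <cstdint>
#include <cstdio>

struct Vec4
{
    std::uint16_t data[4]; // four fp16-sized scalars packed together
};

int main()
{
    Vec4 buf[8] = {};
    const Vec4* p_vec             = buf;
    const std::uint16_t* p_scalar = &buf[0].data[0];

    // "+ 3" on the vector pointer skips 3 * 4 scalars (24 bytes); the same
    // "+ 3" on a scalar pointer skips only 3 scalars (6 bytes).
    std::printf("vector offset 3 -> byte offset %td\n",
                (const char*)(p_vec + 3) - (const char*)buf);
    std::printf("scalar offset 3 -> byte offset %td\n",
                (const char*)(p_scalar + 3) - (const char*)buf);
}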