Commit f6934e0b authored by Chao Liu
Browse files

refactor

parent 5096a157
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
#include "blockwise_convolution.cuh" #include "blockwise_convolution.cuh"
template <class TFloat, template <class TFloat,
class InDesc, class InGlobalDesc,
class WeiDesc, class WeiGlobalDesc,
class OutDesc, class OutGlobalDesc,
unsigned OutTileSizeH, unsigned OutTileSizeH,
unsigned OutTileSizeW, unsigned OutTileSizeW,
unsigned NPerBlock, unsigned NPerBlock,
...@@ -20,24 +20,24 @@ template <class TFloat, ...@@ -20,24 +20,24 @@ template <class TFloat,
unsigned NBlockOpLen3, unsigned NBlockOpLen3,
unsigned BlockSize, unsigned BlockSize,
unsigned GridSize> unsigned GridSize>
__global__ void gridwise_convolution(InDesc, __global__ void gridwise_convolution(InGlobalDesc,
TFloat* const __restrict__ p_in_glb, TFloat* const __restrict__ p_in_global,
WeiDesc, WeiGlobalDesc,
TFloat* const __restrict__ p_wei_glb, TFloat* const __restrict__ p_wei_global,
OutDesc, OutGlobalDesc,
TFloat* __restrict__ p_out_glb) TFloat* __restrict__ p_out_global)
{ {
constexpr auto I0 = Index<0>{}; constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{}; constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{}; constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{}; constexpr auto I3 = Index<3>{};
constexpr auto in_desc = InDesc{}; constexpr auto in_global_desc = InGlobalDesc{};
constexpr auto wei_desc = WeiDesc{}; constexpr auto wei_global_desc = WeiGlobalDesc{};
constexpr auto out_desc = OutDesc{}; constexpr auto out_global_desc = OutGlobalDesc{};
constexpr unsigned S = wei_desc.GetLength(I2); constexpr unsigned S = wei_global_desc.GetLength(I2);
constexpr unsigned R = wei_desc.GetLength(I3); constexpr unsigned R = wei_global_desc.GetLength(I3);
constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock; constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock; constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;
...@@ -45,34 +45,34 @@ __global__ void gridwise_convolution(InDesc, ...@@ -45,34 +45,34 @@ __global__ void gridwise_convolution(InDesc,
constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1; constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1; constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;
constexpr unsigned NBlockWork = (out_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork = (out_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned YBlockWork = (out_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; constexpr unsigned YBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned XBlockWork = (out_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; constexpr unsigned XBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
constexpr auto in_block_glb_desc = make_ConstantTensorDescriptor( constexpr auto in_block_src_desc = make_ConstantTensorDescriptor(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_desc.GetStrides()); Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
constexpr auto wei_block_glb_desc = make_ConstantTensorDescriptor( constexpr auto wei_block_src_desc = make_ConstantTensorDescriptor(
Sequence<KPerBlock, CPerBlock, S, R>{}, wei_desc.GetStrides()); Sequence<KPerBlock, CPerBlock, S, R>{}, wei_global_desc.GetStrides());
constexpr auto out_block_glb_desc = make_ConstantTensorDescriptor( constexpr auto out_block_src_desc = make_ConstantTensorDescriptor(
Sequence<NPerBlock, KPerBlock, HoPerBlock, WoPerBlock>{}, out_desc.GetStrides()); Sequence<NPerBlock, KPerBlock, HoPerBlock, WoPerBlock>{}, out_global_desc.GetStrides());
constexpr auto in_block_lds_desc = constexpr auto in_block_dst_desc =
make_ConstantTensorDescriptor(in_block_glb_desc.GetLengths()); make_ConstantTensorDescriptor(in_block_src_desc.GetLengths());
constexpr auto wei_block_lds_desc = constexpr auto wei_block_dst_desc =
make_ConstantTensorDescriptor(wei_block_glb_desc.GetLengths()); make_ConstantTensorDescriptor(wei_block_src_desc.GetLengths());
constexpr auto out_block_lds_desc = constexpr auto out_block_dst_desc =
make_ConstantTensorDescriptor(out_block_glb_desc.GetLengths()); make_ConstantTensorDescriptor(out_block_src_desc.GetLengths());
constexpr unsigned in_block_size = in_block_lds_desc.GetElementSpace(); constexpr unsigned in_block_size = in_block_dst_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_block_lds_desc.GetElementSpace(); constexpr unsigned wei_block_size = wei_block_dst_desc.GetElementSpace();
constexpr unsigned out_block_size = out_block_lds_desc.GetElementSpace(); constexpr unsigned out_block_size = out_block_dst_desc.GetElementSpace();
__shared__ TFloat p_in_block_lds[in_block_size]; __shared__ TFloat p_in_block[in_block_size];
__shared__ TFloat p_wei_block_lds[wei_block_size]; __shared__ TFloat p_wei_block[wei_block_size];
__shared__ TFloat p_out_block_lds[out_block_size]; __shared__ TFloat p_out_block[out_block_size];
const unsigned block_id = blockIdx.x; const unsigned block_id = blockIdx.x;
...@@ -98,15 +98,15 @@ __global__ void gridwise_convolution(InDesc, ...@@ -98,15 +98,15 @@ __global__ void gridwise_convolution(InDesc,
#if 0 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
print_ConstantTensorDescriptor( in_desc, "gridwise_convolution: in_desc: "); print_ConstantTensorDescriptor( in_global_desc, "gridwise_convolution: in_global_desc: ");
print_ConstantTensorDescriptor(wei_desc, "gridwise_convolution: wei_desc: "); print_ConstantTensorDescriptor(wei_global_desc, "gridwise_convolution: wei_global_desc: ");
print_ConstantTensorDescriptor(out_desc, "gridwise_convolution: out_desc: "); print_ConstantTensorDescriptor(out_global_desc, "gridwise_convolution: out_global_desc: ");
print_ConstantTensorDescriptor( in_block_glb_desc, "gridwise_convolution: in_block_glb_desc: "); print_ConstantTensorDescriptor( in_block_src_desc, "gridwise_convolution: in_block_src_desc: ");
print_ConstantTensorDescriptor(wei_block_glb_desc, "gridwise_convolution: wei_block_glb_desc: "); print_ConstantTensorDescriptor(wei_block_src_desc, "gridwise_convolution: wei_block_src_desc: ");
print_ConstantTensorDescriptor(out_block_glb_desc, "gridwise_convolution: out_block_glb_desc: "); print_ConstantTensorDescriptor(out_block_src_desc, "gridwise_convolution: out_block_src_desc: ");
print_ConstantTensorDescriptor( in_block_lds_desc, "gridwise_convolution: in_block_lds_desc: "); print_ConstantTensorDescriptor( in_block_dst_desc, "gridwise_convolution: in_block_dst_desc: ");
print_ConstantTensorDescriptor(wei_block_lds_desc, "gridwise_convolution: wei_block_lds_desc: "); print_ConstantTensorDescriptor(wei_block_dst_desc, "gridwise_convolution: wei_block_dst_desc: ");
print_ConstantTensorDescriptor(out_block_lds_desc, "gridwise_convolution: out_block_lds_desc: "); print_ConstantTensorDescriptor(out_block_dst_desc, "gridwise_convolution: out_block_dst_desc: ");
printf("NBlockWork %u, KBlockWork %u, YBlockWork %u, XBlockWork %u \t" printf("NBlockWork %u, KBlockWork %u, YBlockWork %u, XBlockWork %u \t"
"block_id %u, n_block_work_id %u, k_block_work_id %u, y_block_work_id %u, " "block_id %u, n_block_work_id %u, k_block_work_id %u, y_block_work_id %u, "
...@@ -129,51 +129,52 @@ __global__ void gridwise_convolution(InDesc, ...@@ -129,51 +129,52 @@ __global__ void gridwise_convolution(InDesc,
// set output tensor in LDS to 0 // set output tensor in LDS to 0
blockwise_4d_tensor_op_unary<TFloat, blockwise_4d_tensor_op_unary<TFloat,
decltype(out_block_lds_desc), decltype(out_block_dst_desc),
NBlockOpLen0, NBlockOpLen0,
NBlockOpLen1, NBlockOpLen1,
NBlockOpLen2, NBlockOpLen2,
NBlockOpLen3, NBlockOpLen3,
decltype(f_set0), decltype(f_set0),
BlockSize>(out_block_lds_desc, p_out_block_lds, f_set0); BlockSize>(out_block_dst_desc, p_out_block, f_set0);
for(unsigned c_block_work_begin = 0; c_block_work_begin < in_desc.GetLength(I1); for(unsigned c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
c_block_work_begin += CPerBlock) c_block_work_begin += CPerBlock)
{ {
// copy input tensor to LDS // copy input tensor to LDS
blockwise_4d_tensor_op_binary<TFloat, blockwise_4d_tensor_op_binary<TFloat,
decltype(in_block_glb_desc), decltype(in_block_src_desc),
decltype(in_block_lds_desc), decltype(in_block_dst_desc),
NBlockOpLen0, NBlockOpLen0,
NBlockOpLen1, NBlockOpLen1,
NBlockOpLen2, NBlockOpLen2,
NBlockOpLen3, NBlockOpLen3,
decltype(f_copy), decltype(f_copy),
BlockSize>( BlockSize>(
in_block_glb_desc, in_block_src_desc,
p_in_glb + in_block_glb_desc.Get1dIndex(n_block_work_begin, p_in_global + in_block_src_desc.Get1dIndex(n_block_work_begin,
c_block_work_begin, c_block_work_begin,
hi_block_work_begin, hi_block_work_begin,
wi_block_work_begin), wi_block_work_begin),
in_block_lds_desc, in_block_dst_desc,
p_in_block_lds, p_in_block,
f_copy); f_copy);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_4d_tensor_op_binary<TFloat, blockwise_4d_tensor_op_binary<TFloat,
decltype(wei_block_glb_desc), decltype(wei_block_src_desc),
decltype(wei_block_lds_desc), decltype(wei_block_dst_desc),
NBlockOpLen0, NBlockOpLen0,
NBlockOpLen1, NBlockOpLen1,
NBlockOpLen2, NBlockOpLen2,
NBlockOpLen3, NBlockOpLen3,
decltype(f_copy), decltype(f_copy),
BlockSize>( BlockSize>(
wei_block_glb_desc, wei_block_src_desc,
p_wei_glb + wei_block_glb_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0), p_wei_global +
wei_block_lds_desc, wei_block_src_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
p_wei_block_lds, wei_block_dst_desc,
p_wei_block,
f_copy); f_copy);
#if 1 #if 1
...@@ -182,17 +183,17 @@ __global__ void gridwise_convolution(InDesc, ...@@ -182,17 +183,17 @@ __global__ void gridwise_convolution(InDesc,
// blockwise convolution // blockwise convolution
blockwise_convolution<TFloat, blockwise_convolution<TFloat,
decltype(in_block_lds_desc), decltype(in_block_dst_desc),
decltype(wei_block_lds_desc), decltype(wei_block_dst_desc),
decltype(out_block_lds_desc), decltype(out_block_dst_desc),
OutTileSizeH, OutTileSizeH,
OutTileSizeW, OutTileSizeW,
BlockSize>(in_block_lds_desc, BlockSize>(in_block_dst_desc,
p_in_block_lds, p_in_block,
wei_block_lds_desc, wei_block_dst_desc,
p_wei_block_lds, p_wei_block,
out_block_lds_desc, out_block_dst_desc,
p_out_block_lds); p_out_block);
#if 1 #if 1
__syncthreads(); __syncthreads();
...@@ -201,19 +202,19 @@ __global__ void gridwise_convolution(InDesc, ...@@ -201,19 +202,19 @@ __global__ void gridwise_convolution(InDesc,
// copy output tensor from LDS to device mem // copy output tensor from LDS to device mem
blockwise_4d_tensor_op_binary<TFloat, blockwise_4d_tensor_op_binary<TFloat,
decltype(out_block_lds_desc), decltype(out_block_dst_desc),
decltype(out_block_glb_desc), decltype(out_block_src_desc),
NBlockOpLen0, NBlockOpLen0,
NBlockOpLen1, NBlockOpLen1,
NBlockOpLen2, NBlockOpLen2,
NBlockOpLen3, NBlockOpLen3,
decltype(f_copy), decltype(f_copy),
BlockSize>( BlockSize>(
out_block_lds_desc, out_block_dst_desc,
p_out_block_lds, p_out_block,
out_block_glb_desc, out_block_src_desc,
p_out_glb + p_out_global +
out_block_glb_desc.Get1dIndex( out_block_src_desc.Get1dIndex(
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin), n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin),
f_copy); f_copy);
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment