Commit 7faf269c authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 03eef73c
#pragma once #pragma once
#include "common.hip.hpp" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_direct_convolution.hip.hpp" #include "blockwise_direct_convolution.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp" #include "threadwise_4d_tensor_op.hip.hpp"
...@@ -20,10 +21,12 @@ template <class Float, ...@@ -20,10 +21,12 @@ template <class Float,
unsigned CPerThread, unsigned CPerThread,
unsigned HoPerThread, unsigned HoPerThread,
unsigned WoPerThread, unsigned WoPerThread,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead,
unsigned BlockSize, unsigned BlockSize,
unsigned GridSize> unsigned GridSize>
__global__ void __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) Float* const __restrict__ p_out_global)
{ {
...@@ -32,50 +35,72 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -32,50 +35,72 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto in_global_desc = InGlobalDesc{}; constexpr auto in_nchw_global_desc = InGlobalDesc{};
constexpr auto wei_global_desc = WeiGlobalDesc{}; constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
constexpr auto out_global_desc = OutGlobalDesc{}; constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
constexpr unsigned Y = wei_global_desc.GetLength(I2); constexpr unsigned N = in_nchw_global_desc.GetLength(I0);
constexpr unsigned X = wei_global_desc.GetLength(I3); constexpr unsigned K = wei_kcyx_global_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_global_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_global_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_global_desc.GetLength(I3);
constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1; constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr auto in_block_desc = constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}); Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<KPerBlock, CPerBlock * Y * X>{},
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
constexpr auto wei_block_desc = constexpr auto wei_kcyx_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{}); make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});
// shared mem // shared mem
constexpr unsigned in_block_size = in_block_desc.GetElementSpace(); constexpr unsigned in_block_size =
constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace(); in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float p_in_block[in_block_size]; __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[wei_block_size]; __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// threadwise tensors // threadwise tensors
constexpr unsigned HiPerThread = HoPerThread + Y - 1; constexpr unsigned HiPerThread = HoPerThread + Y - 1;
constexpr unsigned WiPerThread = WoPerThread + X - 1; constexpr unsigned WiPerThread = WoPerThread + X - 1;
constexpr auto in_thread_block_desc = make_ConstantTensorDescriptor( constexpr auto in_nchw_thread_block_desc =
Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{}, in_block_desc.GetStrides()); make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
in_nchw_block_desc.GetStrides());
constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor( constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, CPerThread, Y, X>{}, wei_block_desc.GetStrides()); Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());
constexpr auto out_thread_desc = get_convolution_output_default_4d_tensor_descriptor( constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
in_thread_block_desc, wei_thread_block_desc); in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
// register // register
Float p_out_thread[out_thread_desc.GetElementSpace()]; Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
// divide block work // divide block work
constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; constexpr unsigned NBlockWork =
constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; constexpr unsigned KBlockWork =
constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork =
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
const unsigned block_id = blockIdx.x; const unsigned block_id = blockIdx.x;
...@@ -122,33 +147,44 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -122,33 +147,44 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
constexpr auto blockwise_in_copy = constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize, Blockwise4dTensorCopy1<BlockSize,
Float, Float,
decltype(in_global_desc), decltype(in_nchw_global_desc),
decltype(in_block_desc), decltype(in_nchw_block_desc),
decltype(in_block_desc.GetLengths())>{}; decltype(in_nchw_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#if 0
constexpr auto blockwise_wei_copy = constexpr auto blockwise_wei_copy =
Blockwise4dTensorCopy1<BlockSize, Blockwise4dTensorCopy1<BlockSize,
Float, Float,
decltype(wei_global_desc), decltype(wei_kcyx_global_desc),
decltype(wei_block_desc), decltype(wei_kcyx_block_desc),
decltype(wei_block_desc.GetLengths())>{}; decltype(wei_kcyx_block_desc.GetLengths()),
1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ke_global_desc),
decltype(wei_ke_block_desc),
decltype(wei_ke_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// set threadwise output tensor to 0 // set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1); for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
// copy input tensor to LDS // copy input tensor to LDS
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin, blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
c_block_data_begin, c_block_data_begin,
hi_block_data_begin, hi_block_data_begin,
wi_block_data_begin), wi_block_data_begin),
p_in_block); p_in_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run( blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex(
p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_block); p_wei_block);
__syncthreads(); __syncthreads();
...@@ -158,25 +194,27 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -158,25 +194,27 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
// threadwise convolution // threadwise convolution
#if 1 #if 1
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
wei_thread_block_desc, wei_kcyx_thread_block_desc,
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), p_wei_block +
out_thread_desc, wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc,
p_out_thread); p_out_thread);
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_thread_block_desc, in_nchw_thread_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
wei_thread_block_desc, wei_kcyx_thread_block_desc,
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), p_wei_block +
out_thread_desc, wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc,
p_out_thread); p_out_thread);
#endif #endif
} }
...@@ -184,12 +222,12 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -184,12 +222,12 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
// copy output tensor from register to global mem // copy output tensor from register to global mem
threadwise_4d_tensor_copy( threadwise_4d_tensor_copy(
out_thread_desc, out_nkhw_thread_desc,
p_out_thread, p_out_thread,
out_global_desc, out_nkhw_global_desc,
p_out_global + out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
out_thread_desc.GetLengths()); out_nkhw_thread_desc.GetLengths());
} }
...@@ -158,7 +158,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -158,7 +158,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
Float, Float,
decltype(wei_kcyx_global_desc), decltype(wei_kcyx_global_desc),
decltype(wei_kcyx_block_desc), decltype(wei_kcyx_block_desc),
decltype(wei_kcyx_block_desc.GetLengths())>{}; decltype(wei_kcyx_block_desc.GetLengths()),
1>{};
#elif 1 #elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize, const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float, Float,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment