"vscode:/vscode.git/clone" did not exist on "3727d00bf2d4fe4547b292921efb62603417298f"
Commit 1eafc9c1 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent fee92fb6
......@@ -302,7 +302,7 @@ template <class T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
float error = 0;
float max_diff = 0;
float max_diff = -1;
float ref_value = 0, result_value = 0;
for(int i = 0; i < ref.mData.size(); ++i)
{
......@@ -338,6 +338,14 @@ int main()
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
constexpr unsigned N = 72;
constexpr unsigned C = 288;
constexpr unsigned HI = 38;
constexpr unsigned WI = 38;
constexpr unsigned K = 72;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
......@@ -347,13 +355,13 @@ int main()
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned HI = 4;
constexpr unsigned WI = 4;
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
constexpr unsigned N = 2;
constexpr unsigned C = 3;
......
......@@ -26,13 +26,13 @@ void device_direct_convolution_1(
constexpr auto out_desc = OutDesc{};
constexpr unsigned OutTileSizeH = 2;
constexpr unsigned OutTileSizeW = 2;
constexpr unsigned NPerBlock = 1;
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 16;
constexpr unsigned CPerBlock = 4;
constexpr unsigned YPerBlock = 4;
constexpr unsigned CPerBlock = 2;
constexpr unsigned YPerBlock = 2;
constexpr unsigned XPerBlock = 16;
constexpr unsigned NPerThread = 1;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
......
......@@ -41,8 +41,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;
constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;
constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
constexpr unsigned WiPerBlock = WoPerBlock + R - 1;
constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;
......@@ -102,11 +102,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
constexpr auto wei_transform_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, InTileSizeH, InTileSizeW>{});
constexpr unsigned in_transform_block_size = in_transform_block_desc.GetElementSpace();
constexpr unsigned wei_transform_block_size = wei_transform_block_desc.GetElementSpace();
__shared__ TFloat p_in_transform_block[in_transform_block_size];
__shared__ TFloat p_wei_transform_block[wei_transform_block_size];
__shared__ TFloat p_in_transform_block[in_transform_block_desc.GetElementSpace()];
__shared__ TFloat p_wei_transform_block[wei_transform_block_desc.GetElementSpace()];
// thread data
constexpr auto in_transform_thread_block_desc =
......@@ -126,11 +123,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
constexpr auto out_thread_global_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides());
constexpr unsigned out_transform_thread_size = out_transform_thread_desc.GetElementSpace();
constexpr unsigned out_thread_size = out_thread_desc.GetElementSpace();
TFloat p_out_transform_thread[out_transform_thread_size];
TFloat p_out_thread[out_thread_size];
TFloat p_out_transform_thread[out_transform_thread_desc.GetElementSpace()];
TFloat p_out_thread[out_thread_desc.GetElementSpace()];
#if 0
if(blockIdx.x == 0 && threadIdx.x == 0)
......
......@@ -116,10 +116,13 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, TFloat* __restrict__ p, ID
const unsigned did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - shift : desc.GetLength(I0);
const unsigned did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - shift : desc.GetLength(I1);
const unsigned did2_end =
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - shift : desc.GetLength(I2);
const unsigned did3_end =
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - shift : desc.GetLength(I3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment