Commit a21b0d27 authored by Chao Liu
Browse files

refactor

parent 6790b8f3
...@@ -178,7 +178,7 @@ int main() ...@@ -178,7 +178,7 @@ int main()
for(int i = 0; i < 20; ++i) for(int i = 0; i < 20; ++i)
{ {
device_direct_convolution_2(in_desc, in, wei_desc, wei, out_desc, out_device); device_direct_convolution_1(in_desc, in, wei_desc, wei, out_desc, out_device);
} }
#if 0 #if 0
......
...@@ -55,23 +55,23 @@ void device_direct_convolution_1( ...@@ -55,23 +55,23 @@ void device_direct_convolution_1(
cudaEventCreate(&start); cudaEventCreate(&start);
cudaEventRecord(start, 0); cudaEventRecord(start, 0);
gridwise_convolution<T, gridwise_direct_convolution_1<T,
InDesc, InDesc,
WeiDesc, WeiDesc,
OutDesc, OutDesc,
OutTileSizeH, OutTileSizeH,
OutTileSizeW, OutTileSizeW,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
CPerBlock, CPerBlock,
YPerBlock, YPerBlock,
XPerBlock, XPerBlock,
NBlockOpLen0, NBlockOpLen0,
NBlockOpLen1, NBlockOpLen1,
NBlockOpLen2, NBlockOpLen2,
NBlockOpLen3, NBlockOpLen3,
BlockSize, BlockSize,
GridSize> GridSize>
<<<grid_dim, block_dim>>>(InDesc{}, <<<grid_dim, block_dim>>>(InDesc{},
static_cast<T*>(in_device_buf.GetDeviceBuffer()), static_cast<T*>(in_device_buf.GetDeviceBuffer()),
WeiDesc{}, WeiDesc{},
......
...@@ -59,26 +59,26 @@ void device_direct_convolution_2( ...@@ -59,26 +59,26 @@ void device_direct_convolution_2(
cudaEventCreate(&start); cudaEventCreate(&start);
cudaEventRecord(start, 0); cudaEventRecord(start, 0);
gridwise_convolution<T, gridwise_direct_convolution_2<T,
InDesc, InDesc,
WeiDesc, WeiDesc,
OutDesc, OutDesc,
OutTileSizeH, OutTileSizeH,
OutTileSizeW, OutTileSizeW,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
CPerBlock, CPerBlock,
YPerBlock, YPerBlock,
XPerBlock, XPerBlock,
NPerThread, NPerThread,
KPerThread, KPerThread,
CPerThread, CPerThread,
NBlockOpLen0, NBlockOpLen0,
NBlockOpLen1, NBlockOpLen1,
NBlockOpLen2, NBlockOpLen2,
NBlockOpLen3, NBlockOpLen3,
BlockSize, BlockSize,
GridSize> GridSize>
<<<grid_dim, block_dim>>>(InDesc{}, <<<grid_dim, block_dim>>>(InDesc{},
static_cast<T*>(in_device_buf.GetDeviceBuffer()), static_cast<T*>(in_device_buf.GetDeviceBuffer()),
WeiDesc{}, WeiDesc{},
......
...@@ -20,12 +20,12 @@ template <class TFloat, ...@@ -20,12 +20,12 @@ template <class TFloat,
unsigned NBlockOpLen3, unsigned NBlockOpLen3,
unsigned BlockSize, unsigned BlockSize,
unsigned GridSize> unsigned GridSize>
__global__ void gridwise_convolution(InGlobalDesc, __global__ void gridwise_direct_convolution_1(InGlobalDesc,
TFloat* const __restrict__ p_in_global, TFloat* const __restrict__ p_in_global,
WeiGlobalDesc, WeiGlobalDesc,
TFloat* const __restrict__ p_wei_global, TFloat* const __restrict__ p_wei_global,
OutGlobalDesc, OutGlobalDesc,
TFloat* __restrict__ p_out_global) TFloat* __restrict__ p_out_global)
{ {
constexpr auto I0 = Index<0>{}; constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{}; constexpr auto I1 = Index<1>{};
......
...@@ -25,12 +25,12 @@ template <class TFloat, ...@@ -25,12 +25,12 @@ template <class TFloat,
unsigned NBlockOpLen3, unsigned NBlockOpLen3,
unsigned BlockSize, unsigned BlockSize,
unsigned GridSize> unsigned GridSize>
__global__ void gridwise_convolution(InGlobalDesc, __global__ void gridwise_direct_convolution_2(InGlobalDesc,
TFloat* const __restrict__ p_in_global, TFloat* const __restrict__ p_in_global,
WeiGlobalDesc, WeiGlobalDesc,
TFloat* const __restrict__ p_wei_global, TFloat* const __restrict__ p_wei_global,
OutGlobalDesc, OutGlobalDesc,
TFloat* __restrict__ p_out_global) TFloat* __restrict__ p_out_global)
{ {
constexpr auto I0 = Index<0>{}; constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{}; constexpr auto I1 = Index<1>{};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment