Commit c64f63d5 authored by Chao Liu
Browse files

refactor

parent 20968472
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "device_direct_convolution_1.cuh" #include "device_direct_convolution_1.cuh"
#include "device_direct_convolution_2.cuh" #include "device_direct_convolution_2.cuh"
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh" #include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
#include "device_implicit_gemm_convolution_1_nchw_srck.cuh" #include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
//#include "device_winograd_convolution.cuh" //#include "device_winograd_convolution.cuh"
...@@ -418,8 +418,8 @@ int main() ...@@ -418,8 +418,8 @@ int main()
device_direct_convolution_2 device_direct_convolution_2
#elif 0 #elif 0
device_implicit_gemm_convolution_1_nchw_kcsr device_implicit_gemm_convolution_1_nchw_kcsr
#elif 0 #elif 1
device_implicit_gemm_convolution_1_nchw_srck device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1 #elif 1
device_implicit_gemm_convolution_2_cnhw_srck_knhw device_implicit_gemm_convolution_2_cnhw_srck_knhw
#elif 0 #elif 0
......
#pragma once #pragma once
#include "gridwise_implicit_gemm_convolution_1_nchw_srck.cuh" #include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include <unistd.h>
template <class T, class InDesc, class WeiDesc, class OutDesc> template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_1_nchw_srck(InDesc, void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcsr, const Tensor<T>& wei_kcsr,
OutDesc, OutDesc,
Tensor<T>& out_nkhw) Tensor<T>& out_nkhw,
unsigned nrepeat)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -101,6 +103,19 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc, ...@@ -101,6 +103,19 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
constexpr unsigned HoPerThread = 2; constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128;
#elif 1
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#endif #endif
...@@ -113,40 +128,46 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc, ...@@ -113,40 +128,46 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
cudaEvent_t start, stop; for(unsigned i = 0; i < nrepeat; ++i)
float elapsedTime; {
cudaEvent_t start, stop;
cudaEventCreate(&start); float elapsedTime;
cudaEventRecord(start, 0);
cudaEventCreate(&start);
gridwise_implicit_gemm_convolution_1_nchw_srck<GridSize, cudaEventRecord(start, 0);
BlockSize,
T, gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize,
decltype(in_nchw_desc), BlockSize,
decltype(wei_srck_desc), T,
decltype(out_nkhw_desc), decltype(in_nchw_desc),
NPerBlock, decltype(wei_srck_desc),
KPerBlock, decltype(out_nkhw_desc),
CPerBlock, NPerBlock,
HoPerBlock, KPerBlock,
WoPerBlock, CPerBlock,
KPerThread, HoPerBlock,
CPerThread, WoPerBlock,
HoPerThread, KPerThread,
WoPerThread> CPerThread,
<<<grid_dim, block_dim>>>(in_nchw_desc, HoPerThread,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), WoPerThread>
wei_srck_desc, <<<grid_dim, block_dim>>>(in_nchw_desc,
static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()), static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
out_nkhw_desc, wei_srck_desc,
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())); static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
out_nkhw_desc,
cudaEventCreate(&stop); static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop); cudaEventCreate(&stop);
cudaEventRecord(stop, 0);
cudaEventElapsedTime(&elapsedTime, start, stop); cudaEventSynchronize(stop);
printf("Elapsed time : %f ms\n", elapsedTime);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Elapsed time : %f ms\n", elapsedTime);
usleep(10);
}
checkCudaErrors(cudaGetLastError()); checkCudaErrors(cudaGetLastError());
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
......
...@@ -90,6 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -90,6 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2; constexpr unsigned CPerBlock = 2;
constexpr unsigned BPerBatch = 32;
constexpr unsigned BPerThread = 4; constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 16; constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
...@@ -134,7 +136,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -134,7 +136,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
CPerBlock, CPerBlock,
BPerThread, BPerThread,
KPerThread, KPerThread,
CPerThread> CPerThread,
BPerBatch>
<<<grid_dim, block_dim>>>(in_cnhw_desc, <<<grid_dim, block_dim>>>(in_cnhw_desc,
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()), static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
wei_srck_desc, wei_srck_desc,
......
...@@ -22,7 +22,7 @@ template <unsigned GridSize, ...@@ -22,7 +22,7 @@ template <unsigned GridSize,
unsigned HoPerThread, unsigned HoPerThread,
unsigned WoPerThread> unsigned WoPerThread>
__global__ void __global__ void
gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc, gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
Float* const __restrict__ p_in_global, Float* const __restrict__ p_in_global,
WeiGlobalDesc, WeiGlobalDesc,
Float* const __restrict__ p_wei_global, Float* const __restrict__ p_wei_global,
......
...@@ -19,7 +19,8 @@ template <unsigned GridSize, ...@@ -19,7 +19,8 @@ template <unsigned GridSize,
unsigned CPerBlock, unsigned CPerBlock,
unsigned BPerThread, unsigned BPerThread,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread> unsigned CPerThread,
unsigned BPerBatch>
__global__ void __global__ void
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
Float* const __restrict__ p_in_global, Float* const __restrict__ p_in_global,
...@@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n");
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<CPerBlock>{},
Number<BPerBlock>{}, Number<BPerBatch>{},
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
const auto blockwise_gemm = const auto blockwise_batched_gemm =
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize, blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
decltype(a_cxk_block_mtx_desc), decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc), decltype(b_cxb_block_mtx_desc),
...@@ -128,9 +131,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -128,9 +131,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
false, false,
false, false,
0, 0,
BPerBatch,
0, 0,
0, BPerBlock/BPerBatch,
1,
1, 1,
CPerThread, CPerThread,
true>{}; true>{};
...@@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
{ {
auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
p_in_block + s * Wi + r, p_in_block + s * Wi + r,
p_out_thread, p_out_thread,
f_accum); f_accum);
...@@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
// output: register to global mem, // output: register to global mem,
const auto matrix_c_index = const auto matrix_c_index =
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row_begin; const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned b_thread_data_begin = matrix_c_index.col_begin; const unsigned b_thread_data_begin = matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment