Commit c64f63d5 authored by Chao Liu

refactor

parent 20968472
@@ -9,7 +9,7 @@
 #include "device_direct_convolution_1.cuh"
 #include "device_direct_convolution_2.cuh"
 #include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
-#include "device_implicit_gemm_convolution_1_nchw_srck.cuh"
+#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
 #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
 //#include "device_winograd_convolution.cuh"
@@ -418,8 +418,8 @@ int main()
 device_direct_convolution_2
 #elif 0
 device_implicit_gemm_convolution_1_nchw_kcsr
-#elif 0
-device_implicit_gemm_convolution_1_nchw_srck
+#elif 1
+device_implicit_gemm_convolution_1_nchw_srck_nkhw
 #elif 1
 device_implicit_gemm_convolution_2_cnhw_srck_knhw
 #elif 0
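Note: `#if`/`#elif` takes the first branch whose condition is true, so with this change `device_implicit_gemm_convolution_1_nchw_srck_nkhw` runs even though a later `#elif 1` also evaluates to true. A self-contained sketch of the selection pattern (the function bodies here are placeholders, not the real drivers):

```cpp
#include <cstdio>

void run_direct() { std::printf("direct convolution\n"); }
void run_implicit_gemm_v1() { std::printf("implicit GEMM 1 (nchw_srck_nkhw)\n"); }
void run_implicit_gemm_v2() { std::printf("implicit GEMM 2 (cnhw_srck_knhw)\n"); }

int main()
{
#if 0
    run_direct();
#elif 1
    run_implicit_gemm_v1(); // first true branch wins ...
#elif 1
    run_implicit_gemm_v2(); // ... so this branch is dead code
#endif
    return 0;
}
```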
device_implicit_gemm_convolution_1_nchw_srck.cuh → device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
 #pragma once
-#include "gridwise_implicit_gemm_convolution_1_nchw_srck.cuh"
+#include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
+#include <unistd.h>

 template <class T, class InDesc, class WeiDesc, class OutDesc>
-void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
+void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
     const Tensor<T>& in_nchw,
     WeiDesc,
     const Tensor<T>& wei_kcsr,
     OutDesc,
-    Tensor<T>& out_nkhw)
+    Tensor<T>& out_nkhw,
+    unsigned nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -101,6 +103,19 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
     constexpr unsigned HoPerThread = 2;
     constexpr unsigned WoPerThread = 1;
     constexpr unsigned BlockSize = 128;
+#elif 1
+    constexpr unsigned NPerBlock  = 2;
+    constexpr unsigned KPerBlock  = 32;
+    constexpr unsigned CPerBlock  = 4;
+    constexpr unsigned HoPerBlock = 2;
+    constexpr unsigned WoPerBlock = 32;
+
+    constexpr unsigned KPerThread  = 4;
+    constexpr unsigned CPerThread  = 2;
+    constexpr unsigned HoPerThread = 2;
+    constexpr unsigned WoPerThread = 2;
+
+    constexpr unsigned BlockSize = 128;
 #endif
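The `GridSize` computation itself is elided from this diff. Typically with this style of tiling there is one thread block per output tile; a hedged, self-contained sketch under an assumed problem shape (`N`, `K`, `Ho`, `Wo` are illustrative values, not from this commit):

```cpp
#include <cstdio>

int main()
{
    // Tile sizes from the "#elif 1" branch above.
    constexpr unsigned NPerBlock = 2, KPerBlock = 32, HoPerBlock = 2, WoPerBlock = 32;

    // Hypothetical output shape, chosen so every dimension divides evenly.
    constexpr unsigned N = 64, K = 64, Ho = 64, Wo = 64;

    // One thread block per N x K x Ho x Wo output tile.
    constexpr unsigned GridSize =
        (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

    std::printf("GridSize = %u\n", GridSize); // 32 * 2 * 32 * 2 = 4096
    return 0;
}
```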
@@ -113,40 +128,46 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    cudaEvent_t start, stop;
-    float elapsedTime;
-
-    cudaEventCreate(&start);
-    cudaEventRecord(start, 0);
-
-    gridwise_implicit_gemm_convolution_1_nchw_srck<GridSize,
-                                                   BlockSize,
-                                                   T,
-                                                   decltype(in_nchw_desc),
-                                                   decltype(wei_srck_desc),
-                                                   decltype(out_nkhw_desc),
-                                                   NPerBlock,
-                                                   KPerBlock,
-                                                   CPerBlock,
-                                                   HoPerBlock,
-                                                   WoPerBlock,
-                                                   KPerThread,
-                                                   CPerThread,
-                                                   HoPerThread,
-                                                   WoPerThread>
-        <<<grid_dim, block_dim>>>(in_nchw_desc,
-                                  static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-                                  wei_srck_desc,
-                                  static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
-                                  out_nkhw_desc,
-                                  static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-
-    cudaEventCreate(&stop);
-    cudaEventRecord(stop, 0);
-    cudaEventSynchronize(stop);
-
-    cudaEventElapsedTime(&elapsedTime, start, stop);
-    printf("Elapsed time : %f ms\n", elapsedTime);
+    for(unsigned i = 0; i < nrepeat; ++i)
+    {
+        cudaEvent_t start, stop;
+        float elapsedTime;
+
+        cudaEventCreate(&start);
+        cudaEventRecord(start, 0);
+
+        gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize,
+                                                            BlockSize,
+                                                            T,
+                                                            decltype(in_nchw_desc),
+                                                            decltype(wei_srck_desc),
+                                                            decltype(out_nkhw_desc),
+                                                            NPerBlock,
+                                                            KPerBlock,
+                                                            CPerBlock,
+                                                            HoPerBlock,
+                                                            WoPerBlock,
+                                                            KPerThread,
+                                                            CPerThread,
+                                                            HoPerThread,
+                                                            WoPerThread>
+            <<<grid_dim, block_dim>>>(in_nchw_desc,
+                                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                                      wei_srck_desc,
+                                      static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
+                                      out_nkhw_desc,
+                                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+
+        cudaEventCreate(&stop);
+        cudaEventRecord(stop, 0);
+        cudaEventSynchronize(stop);
+
+        cudaEventElapsedTime(&elapsedTime, start, stop);
+        printf("Elapsed time : %f ms\n", elapsedTime);
+
+        usleep(10);
+    }

     checkCudaErrors(cudaGetLastError());
     out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
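The loop above wraps each launch in the standard CUDA event timing idiom: record a start event, launch, record a stop event, synchronize on it, then query the elapsed time. A minimal self-contained version with a placeholder kernel:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* p) { p[threadIdx.x] += 1.0f; }

int main()
{
    float* d_buf;
    cudaMalloc(&d_buf, 256 * sizeof(float));

    const unsigned nrepeat = 5;
    for(unsigned i = 0; i < nrepeat; ++i)
    {
        cudaEvent_t start, stop;
        cudaEventCreate(&start);
        cudaEventCreate(&stop);

        cudaEventRecord(start, 0);       // enqueue start marker on stream 0
        dummy_kernel<<<1, 256>>>(d_buf); // work being timed
        cudaEventRecord(stop, 0);        // enqueue stop marker
        cudaEventSynchronize(stop);      // block host until stop is reached

        float ms = 0.0f;
        cudaEventElapsedTime(&ms, start, stop); // GPU-side elapsed time
        printf("iteration %u: %f ms\n", i, ms);

        cudaEventDestroy(start);
        cudaEventDestroy(stop);
    }

    cudaFree(d_buf);
    return 0;
}
```

The hunk creates the `stop` event only after the launch; since event creation is a host-side call this works, but creating both events before recording, as above, is the conventional ordering.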
device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
@@ -90,6 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
     constexpr unsigned KPerBlock = 64;
     constexpr unsigned CPerBlock = 2;

+    constexpr unsigned BPerBatch = 32;
+
     constexpr unsigned BPerThread = 4;
     constexpr unsigned KPerThread = 16;
     constexpr unsigned CPerThread = 1;
@@ -134,7 +136,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
                                                   CPerBlock,
                                                   BPerThread,
                                                   KPerThread,
-                                                  CPerThread>
+                                                  CPerThread,
+                                                  BPerBatch>
         <<<grid_dim, block_dim>>>(in_cnhw_desc,
                                   static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
                                   wei_srck_desc,
gridwise_implicit_gemm_convolution_1_nchw_srck.cuh → gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
@@ -22,7 +22,7 @@ template <unsigned GridSize,
           unsigned HoPerThread,
           unsigned WoPerThread>
 __global__ void
-gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc,
+gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
     Float* const __restrict__ p_in_global,
     WeiGlobalDesc,
     Float* const __restrict__ p_wei_global,
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
@@ -19,7 +19,8 @@ template <unsigned GridSize,
           unsigned CPerBlock,
           unsigned BPerThread,
           unsigned KPerThread,
-          unsigned CPerThread>
+          unsigned CPerThread,
+          unsigned BPerBatch>
 __global__ void
 gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
     Float* const __restrict__ p_in_global,
@@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
     const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
         Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile

+    static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n");
+
     const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
         Number<CPerBlock>{},
-        Number<BPerBlock>{},
+        Number<BPerBatch>{},
         Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile

     const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
         Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile

-    const auto blockwise_gemm =
+    const auto blockwise_batched_gemm =
         blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
             decltype(a_cxk_block_mtx_desc),
             decltype(b_cxb_block_mtx_desc),
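The new `static_assert` encodes a two-level split of the block's B extent: `BPerBlock` columns divide into `BPerBlock / BPerBatch` batches of `BPerBatch` columns, and each batch divides into `BPerThread`-wide slices. A host-side sketch of the decomposition (`BPerBlock = 128` is an assumption; its actual value lies outside this hunk):

```cpp
#include <cstdio>

int main()
{
    constexpr unsigned BPerBlock  = 128; // assumed for illustration
    constexpr unsigned BPerBatch  = 32;  // from this file
    constexpr unsigned BPerThread = 4;   // from this file

    static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0,
                  "B cannot be evenly divided");

    constexpr unsigned BatchPerBlock   = BPerBlock / BPerBatch;  // GEMM batch count
    constexpr unsigned SlicesPerBatch  = BPerBatch / BPerThread; // per-thread slices

    // Each (batch, slice) pair covers a disjoint range of the B dimension.
    for(unsigned batch = 0; batch < BatchPerBlock; ++batch)
        for(unsigned slice = 0; slice < SlicesPerBatch; ++slice)
            std::printf("batch %u, slice %u -> B range [%u, %u)\n",
                        batch, slice,
                        batch * BPerBatch + slice * BPerThread,
                        batch * BPerBatch + (slice + 1) * BPerThread);
    return 0;
}
```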
@@ -128,9 +131,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
             false,
             false,
-            0,
+            BPerBatch,
             0,
             0,
-            1,
+            BPerBlock/BPerBatch,
             1,
             CPerThread,
             true>{};
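The two changed template arguments plausibly supply the B-operand batch stride (`BPerBatch`) and the batch count (`BPerBlock/BPerBatch`) to the strided batched GEMM. A hedged host-side reference for that access pattern, with small made-up sizes (the real blockwise GEMM operates on tiles in shared memory and registers):

```cpp
#include <cstdio>

// A is shared across batches; B and C advance by a fixed column stride per batch.
int main()
{
    constexpr unsigned CPerBlock = 2, KPerBlock = 4;
    constexpr unsigned BPerBlock = 8, BPerBatch = 4;
    constexpr unsigned BatchCount = BPerBlock / BPerBatch; // cf. the "1" -> "BPerBlock/BPerBatch" change

    float a[CPerBlock][KPerBlock]; // A: C x K (used transposed, as a_cxk above)
    float b[CPerBlock][BPerBlock]; // B: C x B, batches are column slices of width BPerBatch
    float c[KPerBlock][BPerBlock] = {}; // C: K x B accumulator

    for(unsigned i = 0; i < CPerBlock * KPerBlock; ++i) a[i / KPerBlock][i % KPerBlock] = 1.0f;
    for(unsigned i = 0; i < CPerBlock * BPerBlock; ++i) b[i / BPerBlock][i % BPerBlock] = 2.0f;

    for(unsigned g = 0; g < BatchCount; ++g)              // batch loop
        for(unsigned k = 0; k < KPerBlock; ++k)
            for(unsigned col = 0; col < BPerBatch; ++col) // column within the batch
                for(unsigned cc = 0; cc < CPerBlock; ++cc)
                    c[k][g * BPerBatch + col] += a[cc][k] * b[cc][g * BPerBatch + col];

    std::printf("c[0][0] = %f\n", c[0][0]); // 2 * (1 * 2) = 4
    return 0;
}
```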
@@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
     {
         auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-        blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+        blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                    p_in_block + s * Wi + r,
                                    p_out_thread,
                                    f_accum);
@@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
     // output: register to global mem,
     const auto matrix_c_index =
-        blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());

     const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-    const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+    const unsigned b_thread_data_begin =
+        matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;

     const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
     const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
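`b_thread_data_begin` now combines the thread's batch offset with its column offset inside the batch. A tiny illustration with a hypothetical stand-in for the index struct (field names mirror the diff; the values are made up):

```cpp
#include <cstdio>

// Hypothetical stand-in for the struct returned by CalculateThreadMatrixCIndex().
struct MatrixCIndex
{
    unsigned batch_begin; // which B batch this thread writes
    unsigned row_begin;   // K offset within the block tile
    unsigned col_begin;   // B offset within the batch
};

int main()
{
    constexpr unsigned BPerBatch = 32;

    // e.g. a thread assigned batch 2, row 16, column 4 of its batch:
    const MatrixCIndex matrix_c_index{2, 16, 4};

    const unsigned k_thread_data_begin = matrix_c_index.row_begin;
    const unsigned b_thread_data_begin =
        matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;

    std::printf("k offset %u, b offset %u\n",
                k_thread_data_begin, b_thread_data_begin); // 16, 68
    return 0;
}
```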