Commit 84d9802d authored by Chao Liu's avatar Chao Liu

adding implicit gemm

parent aa0199a3
@@ -4,7 +4,7 @@
#include <cstdlib>
#include "nvToolsExt.h"
#include "tensor.hpp"
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
#include "conv_common.cuh"
#include "device_direct_convolution_1.cuh"
#include "device_direct_convolution_2.cuh"
...
@@ -27,7 +27,7 @@ void device_implicit_gemm_convolution(
#if 1
constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 128;
+constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
...
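For a rough sense of scale (an illustrative calculation, not from the commit, assuming a 3x3 filter so S = R = 3 and fp32 data): with the new values, HiPerBlock = 2 + 3 - 1 = 4 and WiPerBlock = 32 + 3 - 1 = 34, so the per-block input tile sized later in this commit holds CPerBlock * HiPerBlock * WiPerBlock * NPerBlock = 4 * 4 * 34 * 2 = 1088 elements and the weight tile holds S * R * CPerBlock * KPerBlock = 3 * 3 * 4 * 64 = 2304 elements, roughly 13.3 KB of LDS in total; with the previous KPerBlock = 128 the weight tile alone would be 4608 elements, about 22.3 KB in total.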
#pragma once
#include "common.cuh"
template <unsigned NRow, unsigned NCol, unsigned RowStride>
struct ConstantMatrixDescriptor
{
__host__ __device__ ConstantMatrixDescriptor()
{
static_assert(NCol <= RowStride, "wrong! NCol > RowStride!");
}
__host__ __device__ constexpr unsigned GetNumberOfRow() const { return NRow; }
__host__ __device__ constexpr unsigned GetNumberOfColumn() const { return NCol; }
__host__ __device__ constexpr unsigned GetRowStride() const { return RowStride; }
__host__ __device__ constexpr unsigned GetElementSize() const { return NRow * NCol; }
__host__ __device__ constexpr unsigned GetElementSpace() const { return NRow * RowStride; }
__host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
{
return irow * RowStride + icol;
}
template <unsigned SubNRow, unsigned SubNCol>
__host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) const
{
return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride>{};
}
};
template <unsigned NRow, unsigned NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
template <unsigned NRow, unsigned NCol, unsigned RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
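For orientation, here is a minimal usage sketch of the ConstantMatrixDescriptor introduced above. It is illustrative only and not part of the commit; it assumes this header and the Number<> helper from common.cuh are on the include path.

#include "ConstantMatrixDescriptor.cuh"

__device__ void example_matrix_descriptor()
{
    // 8x16 row-major matrix with a padded row stride of 20 elements
    constexpr auto mtx = make_ConstantMatrixDescriptor(Number<8>{}, Number<16>{}, Number<20>{});

    static_assert(mtx.GetElementSize() == 8 * 16, "logical element count");
    static_assert(mtx.GetElementSpace() == 8 * 20, "allocated span includes the row padding");

    // linear offset of element (row 2, column 5): 2 * 20 + 5 = 45
    const unsigned offset = mtx.Get1dIndex(2, 5);
    (void)offset;

    // a 4x16 sub-matrix view keeps the parent's row stride
    constexpr auto sub = mtx.MakeSubMatrixDescriptor(Number<4>{}, Number<16>{});
    static_assert(sub.GetRowStride() == 20, "sub-matrix inherits RowStride");
}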
#pragma once
#include "common.cuh"
template <class T, T N>
struct Constant
{
static const T mValue = N;
};
template <unsigned N>
using Number = Constant<unsigned, N>;
template <unsigned... Is>
struct Sequence
{
static constexpr unsigned nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
{
return mData[I];
}
template <unsigned I0, unsigned I1>
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
return Sequence<IR0, IR1>{};
}
template <unsigned I0, unsigned I1, unsigned I2>
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
constexpr unsigned IR2 = Get(Number<I2>{});
return Sequence<IR0, IR1, IR2>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>, Number<I3>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
constexpr unsigned IR2 = Get(Number<I2>{});
constexpr unsigned IR3 = Get(Number<I3>{});
return Sequence<IR0, IR1, IR2, IR3>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto Reorder(Sequence<I0, I1, I2, I3>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
constexpr unsigned IR2 = Get(Number<I2>{});
constexpr unsigned IR3 = Get(Number<I3>{});
return Sequence<IR0, IR1, IR2, IR3>{};
}
};
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
...
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
#include "threadwise_tensor_op.cuh"
#include "threadwise_direct_convolution.cuh"
...
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
template <unsigned BlockSize, class Float, class DstDesc, class F>
__device__ void
...
@@ -12,4 +12,72 @@ struct is_same<T, T>
static const bool value = true;
};
-__device__ unsigned get_thread_local_id() { return threadIdx.x; }
+__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
__device__ unsigned get_block_1d_id() { return blockIdx.x; }
template <class T, T N>
struct Constant
{
static const T mValue = N;
__host__ __device__ constexpr T Get() const { return mValue; }
};
template <unsigned N>
using Number = Constant<unsigned, N>;
template <unsigned... Is>
struct Sequence
{
static constexpr unsigned nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
{
return mData[I];
}
template <unsigned I0, unsigned I1>
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
return Sequence<IR0, IR1>{};
}
template <unsigned I0, unsigned I1, unsigned I2>
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
constexpr unsigned IR2 = Get(Number<I2>{});
return Sequence<IR0, IR1, IR2>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>, Number<I3>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
constexpr unsigned IR2 = Get(Number<I2>{});
constexpr unsigned IR3 = Get(Number<I3>{});
return Sequence<IR0, IR1, IR2, IR3>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto Reorder(Sequence<I0, I1, I2, I3>) const
{
constexpr unsigned IR0 = Get(Number<I0>{});
constexpr unsigned IR1 = Get(Number<I1>{});
constexpr unsigned IR2 = Get(Number<I2>{});
constexpr unsigned IR3 = Get(Number<I3>{});
return Sequence<IR0, IR1, IR2, IR3>{};
}
};
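A small illustrative sketch (not part of the commit) of how the compile-time Number and Sequence helpers added to common.cuh above are used; Reorder follows the same pattern, returning a Sequence whose entries are picked by position.

__device__ void example_compile_time_helpers()
{
    constexpr auto two = Number<2>{};
    static_assert(decltype(two)::mValue == 2, "the value is carried in the type");

    // dimension order N, C, H, W encoded as a compile-time sequence
    constexpr auto nchw = Sequence<0, 1, 2, 3>{};
    static_assert(decltype(nchw)::nDim == 4, "4d sequence");

    // read the I-th entry with a Number<> tag
    const unsigned c_dim = nchw.Get(Number<1>{}); // 1
    (void)two.Get();
    (void)c_dim;
}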
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
// this is ugly, only for 4d
template <class InDesc, class WeiDesc>
...
#pragma once
template <class ThreadMatrixA,
-bool TransA,
-class FloatA,
class ThreadMatrixB,
+class ThreadMatrixC,
+bool TransA,
bool TransB,
+bool TransC,
+class FloatA,
class FloatB,
-class ThreadMatrixC,
class FloatC,
class Accumulator>
__device__ void threadwise_gemm(ThreadMatrixA,
@@ -26,41 +27,51 @@ __device__ void threadwise_gemm(ThreadMatrixA,
template <unsigned BlockSize,
class BlockMatrixA,
class BlockMatrixB,
+class ThreadMatrixC,
bool TransA,
bool TransB,
-unsigned BatchSize,
+bool TransC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
+unsigned ThreadMatrixStrideC,
+unsigned BatchSize,
unsigned BatchPerThread,
-unsigned MPerThread,
+unsigned KPerLoop,
-unsigned NPerThread,
-unsigned KPerThread,
class Accumulator>
struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
{
+unsigned mMyThreadOffsetA = 0;
+unsigned mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch_begin;
-unsigned block_row_begin;
+unsigned row_begin;
-unsigned block_col_begin;
+unsigned col_begin;
};
__device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
{
static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");
-constexpr auto a_block = BlockMatrixA{};
+#if 0
-constexpr auto b_block = BlockMatrixB{};
+constexpr auto a_block_desc = BlockMatrixA{};
+constexpr auto b_block_desc = BlockMatrixB{};
-constexpr auto a_thread = ThreadMatrixA{};
+constexpr unsigned a_thread_row = (!TransA) ? MPerThread : KPerThread;
-constexpr auto b_thread = ThreadMatrixB{};
+constexpr unsigned a_thread_col = (!TransA) ? KPerThread : MPerThread;
-constexpr auto c_thread = ThreadMatrixC{};
+constexpr unsigned b_thread_row = (!TransB) ? KPerThread : NPerThread;
+constexpr unsigned b_thread_col = (!TransB) ? NPerThread : KPerThread;
-constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
+constexpr auto a_thread_desc = ConstantMatrixDescriptor<a_thread_row, a_thread_col>{};
-constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();
+constexpr auto b_thread_desc = ConstantMatrixDescriptor<b_thread_row, b_thread_col>{};
+constexpr auto c_thread_desc = ConstantMatrixDescriptor<MPerThread, NPerThread>{};
-constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
+constexpr unsigned m_block = (!TransA) ? a_block_desc.NRow() : a_block_desc.NCol();
-constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();
+constexpr unsigned n_block = (!TransB) ? b_block_desc.NCol() : b_block_desc.NRow();
+constexpr unsigned m_thread = (!TransA) ? a_thread_desc.NRow() : a_thread_desc.NCol();
+constexpr unsigned n_thread = (!TransB) ? b_thread_desc.NCol() : b_thread_desc.NRow();
constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;
@@ -72,12 +83,17 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
const auto mtx_c_idnex = CalculateThreadMatrixCIndex(get_thread_local_id());
-mMyThreadOffsetA = xxx;
+// mMyThreadOffsetA = xxx;
-mMyThreadoffSetB = xxx;
+// mMyThreadoffSetB = xxx;
+#else
+mMyThreadOffsetA = 0;
+mMyThreadOffsetB = 0;
+#endif
}
__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
{
+#if 0
constexpr auto a_block = BlockMatrixA{};
constexpr auto b_block = BlockMatrixB{};
constexpr auto c_block = BlockMatrixC{};
@@ -104,6 +120,9 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
return MatrixIndex{
batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread};
+#else
+return MatrixIndex{0, 0, 0};
+#endif
}
template <class FloatA, class FloatB, class FloatC>
@@ -111,8 +130,4 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
{
// do something
}
+};
-private:
-unsigned mMyThreadOffsetA = 0;
-unsigned mMyThreadOffsetB = 0;
-}
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "blockwise_direct_convolution.cuh"
...
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "blockwise_direct_convolution.cuh"
#include "threadwise_tensor_op.cuh"
...
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "common.cuh"
+#include "ConstantTensorDescriptor.cuh"
+#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "threadwise_tensor_op.cuh"
+#include "gemm.cuh"
template <unsigned GridSize,
unsigned BlockSize,
@@ -45,59 +48,85 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
constexpr unsigned WiPerBlock = WoPerBlock + R - 1;
-// tensor view of blockwise input and weight in LDS
-constexpr auto in_chwn_block_desc =
-make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
+// divide block work: NCHW
+constexpr unsigned NBlockWork =
+(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
+constexpr unsigned KBlockWork =
+(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
+constexpr unsigned HBlockWork =
+(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
+constexpr unsigned WBlockWork =
+(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
+unsigned itmp = get_block_1d_id();
+const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
+itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
+const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
+itmp -= k_block_work_id * (HBlockWork * WBlockWork);
+const unsigned h_block_work_id = itmp / WBlockWork;
+const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
+const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
+const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
+const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
+const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock;
+const unsigned hi_block_data_begin = ho_block_data_begin;
+const unsigned wi_block_data_begin = wo_block_data_begin;
+// tensor view of blockwise input and weight in LDS
constexpr auto wei_srck_block_desc =
make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});
-// matrix view of blockwise input and weight in LDS
-constexpr auto in_cxhwn_block_mtx_desc = make_ConstantMatrixDescriptor(
-Number<CPerBlock>, Number<HiPerBlock * WiPerBlock * NPerBlock>);
-constexpr auto wei_srcxk_block_mtx_desc =
-make_ConstantMatrixDescriptor(Number<S * R * CPerBlock>, Number<KPerBlock>);
+constexpr auto in_chwn_block_desc =
+make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
+// tensor view of threadwise output in register
+constexpr auto out_hkwn_thread_desc =
+make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
-// LDS
-constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
-constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();
-__shared__ Float p_in_block[in_block_size];
-__shared__ Float p_wei_block[wei_block_size];
-// a series of batched GEMM
-// blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
-// A_matrix and B_matrix saved in LDS, c_matrix saved in register
-// A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
-// B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
-// C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
-constexpr auto a_block_mtx_desc =
-wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(Number<CPerBlock>{}, Number<KPerBlock>{});
-constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor(
-Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{});
-auto f_accum = (auto& c, auto& v) { c += v; };
+// a series of blockwise batched GEMM
+// C_matrix += transpose(A_matrix) * B_matrix
+// A_matrix and B_matrix saved in LDS, C_matrix saved in register
+// A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K]
+// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
+// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
+const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
+Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
+const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
+Number<CPerBlock>{},
+Number<WoPerBlock * NPerBlock>{},
+Number<in_chwn_block_desc.GetStride(I1)>{}); // constexpr doesn't compile
+const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
+Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
+auto f_accum = [](auto& c, auto& ab) { c += ab; };
const auto blockwise_batch_gemm =
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
-a_block_mtx_desc,
-b_block_mtx_desc,
+decltype(a_cxk_block_mtx_desc),
+decltype(b_cxwn_block_mtx_desc),
+decltype(c_kxwn_thread_mtx_desc),
true,
false,
-HoPerBlock,
+false,
0,
-xxx_b_matrix_stride,
+in_chwn_block_desc.GetStride(I1),
+out_hkwn_thread_desc.GetStride(I1),
+HoPerBlock,
HoPerThread,
-KPerThread,
-NPerThread * WoPerThread,
-CPerTrhead,
+CPerThread,
decltype(f_accum)>{};
-// tensor view of threadwise output in register
-constexpr auto out_hkwn_thread_desc =
-make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
+// LDS
+constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
+constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();
+__shared__ Float p_in_block[in_block_size];
+__shared__ Float p_wei_block[wei_block_size];
// register
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
@@ -105,15 +134,19 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
-for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1);
+for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
c_block_data_begin += CPerBlock, __syncthreads())
{
// input: global mem to LDS,
// convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{};
-blockwise_4d_tensor_copy_reorder<BlockSize>(in_nchw_global_desc,
-p_in_global,
+blockwise_4d_tensor_copy_reorder<BlockSize>(
+in_nchw_global_desc,
+p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
+c_block_data_begin,
+hi_block_data_begin,
+wi_block_data_begin),
in_chwn_block_desc,
p_in_block,
in_chwn_block_desc,
@@ -123,30 +156,31 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
// convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{};
-blockwise_4d_tensor_copy_reorder<BlockSize>(wei_csrk_global_desc,
-p_wei_global,
-wei_csrk_block_desc,
+blockwise_4d_tensor_copy_reorder<BlockSize>(
+wei_kcsr_global_desc,
+p_wei_global +
+wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
+wei_srck_block_desc,
p_wei_block,
-wei_csrk_block_desc,
+wei_srck_block_desc,
-reorder_kcsr2csrk);
+reorder_kcsr2srck);
__syncthreads();
-// loop over filter point
+// a series of batched GEMM
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
-blockwise_batch_gemm.run(
-p_wei_block + wei_srcxk_block_mtx_desc.Get1dIndex(xxxxx, xxxx),
-p_in_block + in_cxhwn_block_mtx_desc.Get1dIndex(xxxx, xxxx),
+blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0),
p_out_thread);
}
}
}
const auto matrix_c_index =
-blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_id());
+blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.col_begin;
@@ -160,7 +194,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
out_hkwn_thread_desc,
p_out_thread,
out_nkhw_global_desc,
-p_out_global + out_nkhw_global_desc.GetIndex(n_block_data_begin,
+p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
...
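As a side note on the block-work division near the top of the kernel above: each block's (n, k, ho, wo) tile coordinates are linearized into the 1d block id in row-major order, so they are recovered by successive division and remainder. A minimal standalone sketch with made-up work counts (illustrative only, not part of the commit):

__host__ __device__ void example_block_work_decomposition(unsigned block_id)
{
    // hypothetical work counts, e.g. a grid of 4 * 2 * 8 * 4 = 256 blocks
    constexpr unsigned NBlockWork = 4, KBlockWork = 2, HBlockWork = 8, WBlockWork = 4;

    unsigned itmp = block_id; // e.g. blockIdx.x in [0, NBlockWork * KBlockWork * HBlockWork * WBlockWork)
    const unsigned n_work = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_work * (KBlockWork * HBlockWork * WBlockWork);
    const unsigned k_work = itmp / (HBlockWork * WBlockWork);
    itmp -= k_work * (HBlockWork * WBlockWork);
    const unsigned h_work = itmp / WBlockWork;
    const unsigned w_work = itmp - h_work * WBlockWork;

    // block_id = 37 -> n_work = 0, k_work = 1, h_work = 1, w_work = 1
    (void)NBlockWork; (void)n_work; (void)k_work; (void)h_work; (void)w_work;
}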
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
#include "blockwise_winograd_transform.cuh"
#include "threadwise_winograd_transform.cuh"
...
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
// optimized for scenario if p_in, p_wei, p_out are in register
template <class Float, class InDesc, class WeiDesc, class OutDesc>
...
#pragma once
-#include "constant_tensor_descriptor.cuh"
+#include "ConstantTensorDescriptor.cuh"
template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
...