Commit 1cb98850 authored by Chao Liu

add another version of batch gemm

parent 9f2e8f8b
......@@ -75,6 +75,39 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 1
// for 3x3, 34x34, try
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr unsigned BlockSize = 128;
#elif 0
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
......@@ -131,7 +164,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
#elif 1
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
......@@ -184,7 +217,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
......@@ -212,13 +250,23 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
Sequence<InBlockCopy_ThreadPerDimC,
InBlockCopy_ThreadPerDimH,
InBlockCopy_ThreadPerDimW,
InBlockCopy_ThreadPerDimN>,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>,
WeiBlockCopyDataPerRead,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>,
dim3(GridSize),
dim3(BlockSize),
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
......
......@@ -391,7 +391,7 @@ int main()
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
#elif 0
#elif 1
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
......@@ -593,11 +593,11 @@ int main()
device_implicit_gemm_convolution_1_nchw_kcsr_nkhw
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 0
#elif 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn
#elif 0
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
#elif 1
#elif 0
device_implicit_gemm_convolution_2_chwn_csrk_khwn
#endif
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
......
......@@ -453,8 +453,7 @@ struct Blockwise2dTensorCopy3
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
auto f_copy = [&](unsigned iloop) {
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] =
......@@ -476,6 +475,11 @@ struct Blockwise2dTensorCopy3
{
assert(false);
}
};
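// the same copy body is reused: once per full iteration of the loop below, and once more
// (guarded by the thread id check) for the partial tail along d0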
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
......@@ -486,29 +490,7 @@ struct Blockwise2dTensorCopy3
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] =
p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else
{
assert(false);
}
f_copy(nloop_d0);
}
}
}
......@@ -561,8 +543,7 @@ struct Blockwise2dTensorCopy3
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
auto f_copy = [&](unsigned iloop) {
if(DataPerRead == 1)
{
p_clipboard[iloop] = p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
......@@ -583,6 +564,11 @@ struct Blockwise2dTensorCopy3
{
assert(false);
}
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
......@@ -593,26 +579,7 @@ struct Blockwise2dTensorCopy3
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
if(DataPerRead == 1)
{
p_clipboard[nloop_d0] = p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_clipboard + nloop_d0 * 2)) =
*(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_clipboard + nloop_d0 * 4)) =
*(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
nloop_d0 * src_loop_stride));
}
else
{
assert(false);
}
f_copy(nloop_d0);
}
}
}
......@@ -649,8 +616,7 @@ struct Blockwise2dTensorCopy3
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
auto f_copy = [&](unsigned iloop) {
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] = p_clipboard[iloop];
......@@ -669,6 +635,11 @@ struct Blockwise2dTensorCopy3
{
assert(false);
}
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
......@@ -679,26 +650,7 @@ struct Blockwise2dTensorCopy3
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
if(DataPerRead == 1)
{
p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] = p_clipboard[nloop_d0];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<const Float2*>(p_clipboard + nloop_d0 * 2));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
nloop_d0 * dst_loop_stride)) =
*(reinterpret_cast<const Float4*>(p_clipboard + nloop_d0 * 4));
}
else
{
assert(false);
}
f_copy(nloop_d0);
}
}
}
......
......@@ -337,3 +337,175 @@ struct BlockwiseChwnTensorCopyPadded
}
}
};
// the starting point needs to be aligned to float4, float2 or float
// stride3 needs to be 1 for both source and destination
template <unsigned BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
class ThreadPerDims,
unsigned DataPerRead>
struct Blockwise4dTensorCopy3
{
unsigned mSrcMyThreadOffset;
unsigned mDstMyThreadOffset;
__device__ Blockwise4dTensorCopy3()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
"wrong! only support stride3 == 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(
SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
// we allow out-of-bound reads from src in the D3 dimension,
// but we need to make sure the dst stride is big enough,
// so that the out-of-bound write won't contaminate the next line in dst
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 && L2 % thread_per_d2 == 0,
"wrong! L0, L1, L2 should be divided evenly!\n");
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
"wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr unsigned num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
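// decompose the linear thread id into (d0, d1, d2, d3) coordinates, row-major with d3
// varying fastest; each thread then owns a DataPerRead-wide chunk along d3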
const unsigned thread_id_d0 =
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
unsigned itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
const unsigned thread_id_d2 = itmp / thread_per_d3;
const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
mDstMyThreadOffset = DstDesc{}.Get1dIndex(
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
using Float2 = float2;
using Float4 = float4;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr unsigned num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr unsigned nloop_d1 = L1 / thread_per_d1;
constexpr unsigned nloop_d2 = L2 / thread_per_d2;
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
#pragma unroll
for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
#pragma unroll
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const unsigned src_offset =
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
const unsigned dst_offset =
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
if(DataPerRead == 1)
{
p_dst[dst_offset + mDstMyThreadOffset] =
p_src[src_offset + mSrcMyThreadOffset];
}
else if(DataPerRead == 2)
{
*(reinterpret_cast<Float2*>(p_dst + dst_offset + mDstMyThreadOffset)) =
*(reinterpret_cast<const Float2*>(p_src + src_offset +
mSrcMyThreadOffset));
}
else if(DataPerRead == 4)
{
*(reinterpret_cast<Float4*>(p_dst + dst_offset + mDstMyThreadOffset)) =
*(reinterpret_cast<const Float4*>(p_src + src_offset +
mSrcMyThreadOffset));
}
else
{
assert(false);
}
}
}
}
}
}
};
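// usage sketch (illustrative; src_desc and dst_desc stand for 4d tensor descriptors,
// the thread dims and DataPerRead mirror the 3x3, 34x34 tuning above):
//   const auto blockwise_copy = Blockwise4dTensorCopy3<BlockSize,
//                                                      float,
//                                                      decltype(src_desc),
//                                                      decltype(dst_desc),
//                                                      decltype(dst_desc.GetLengths()),
//                                                      Sequence<4, 4, 2, 4>,
//                                                      4>{};
//   blockwise_copy.Run(p_src, p_dst);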
......@@ -116,6 +116,13 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
}
}
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
{
return MatrixIndex{batch_in_c, m_in_c, n_in_c};
}
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
......@@ -219,6 +226,306 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
}
};
template <unsigned BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop,
unsigned BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
};
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
{
static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not dividable by BatchPerThread");
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
"wrong! wrong blocksize\n");
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
"wrong! Cannot evenly divide work among Level0Cluster\n");
static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
(NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
"wrong! thread work size is wrong\n");
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
printf("%u %u, %u %u %u, %u %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_index.batch,
c_thread_mtx_index.row,
c_thread_mtx_index.col,
mMyThreadOffsetA,
mMyThreadOffsetB);
}
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
{
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
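// the linear thread id is decomposed hierarchically: batch work id first, then the
// level-1 cluster (m, n) position, then the level-0 cluster (m, n) position
// worked example (with the 3x3, 34x34 tuning above: MPerThreadSubC = NPerThreadSubC = 4,
// MLevel0Cluster = 4, NLevel0Cluster = 2, MLevel1Cluster = 2, NLevel1Cluster = 4,
// BatchPerThread = 1): ThreadPerLevel1Cluster = 64, ThreadPerLevel0Cluster = 8, so
// thread 70 gives batch_work_id = 1, cluster_id = 6, level1 (m, n) = (0, 0),
// level0 (m, n) = (3, 0), i.e. MatrixIndex{1, 12, 0}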
unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
unsigned level1_id = cluster_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = cluster_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
{
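// map a (batch, m, n) offset inside this thread's C tile to its distance in the
// block-level C matrix: consecutive sub-tile repeats are MPerLevel1Cluster rows
// (NPerLevel1Cluster columns) apart, while offsets within one
// MPerThreadSubC x NPerThreadSubC sub-tile stay dense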
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread,
Accumulator f_accum) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, B is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
// loop over k
#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
}
// copy B-sub to form B
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
}
// loop over batch
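// the gemm for batch ib is overlapped with prefetching A/B for batch ib + 1;
// the prefetch is skipped for a matrix whose BlockMatrixStride is 0 (shared across batches)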
#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + ib * ThreadMatrixStrideC,
f_accum);
// read next batch of a, b
if(BlockMatrixStrideA != 0)
{
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
p_a_block +
a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
}
}
if(BlockMatrixStrideB != 0)
{
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
p_b_block +
b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
}
}
}
// do last batch of gemm
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
f_accum);
}
}
};
template <unsigned BlockSize,
class BlockMatrixA,
class BlockMatrixB,
......@@ -588,7 +895,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
......
......@@ -63,3 +63,20 @@ struct Sequence
assert(false);
}
};
template <typename T>
__host__ __device__ constexpr T max(T a, T b)
{
return a > b ? a : b;
}
template <typename T>
__host__ __device__ constexpr T min(T a, T b)
{
return a < b ? a : b;
}
__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b)
{
return (a + b - 1) / b;
}
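// e.g. integer_divide_ceil(34, 8) == 5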
......@@ -20,13 +20,20 @@ template <unsigned GridSize,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
class InBlockCopyThreadPerDims,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
unsigned WeiBlockCopyDataPerRead,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop>
__global__ void
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
......@@ -114,12 +121,22 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
// blockwise copy
// input: format is [C, Hi, Wi, N]
#if 0
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths())>{};
#elif 1
const auto blockwise_in_copy = Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy
// format is [CPerBlock*S*R,KPerBlock]
......@@ -131,7 +148,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
......@@ -164,6 +181,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});
#if 0
const auto blockwise_batch_gemm =
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
......@@ -177,8 +195,27 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
out_hkwn_thread_desc.GetStride(I0),
HoPerBlock,
HoPerThread,
CPerThread,
GemmKPerThreadLoop,
true>{};
#else
const auto blockwise_batch_gemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
0,
in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(I0),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
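// here Ho plays the role of the gemm batch (BatchSize = HoPerBlock, BatchPerThread =
// HoPerThread); A is the [C, K] weight tile, shared across batches (BlockMatrixStrideA = 0),
// and B is the [C, Wo * N] input tile strided by in_chwn_block_desc.GetStride(I1)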
#endif
// LDS: be careful of alignment
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
......@@ -210,10 +247,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
p_wei_global_block_begin += CPerBlock * wei_csrk_global_desc.GetStride(I0),
__syncthreads())
{
// input: global mem to LDS,
// input: global mem to LDS
blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
// weight: global mem to LDS,
// weight: global mem to LDS
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
__syncthreads();
......@@ -223,34 +260,26 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
p_out_thread,
f_accum);
[](auto& acc, const auto&& v) { acc += v; });
}
}
}
const auto matrix_c_index =
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
get_block_1d_id(), get_thread_local_1d_id(),
ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
p_out_thread[0]);
#endif
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
#if 0
// for v1 batch-gemm
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock;
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
......@@ -263,4 +292,36 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn);
#else
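// for v2 batch-gemm: walk each (ho, k, wo, n) element of the thread output buffer, use
// GetDistanceFromBeginOfThreadMatrixC to recover its block-level position, and write it
// to out_global in [K, Ho, Wo, N] layout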
for(unsigned ho = 0; ho < out_hkwn_thread_desc.GetLength(I0); ++ho)
{
for(unsigned k = 0; k < out_hkwn_thread_desc.GetLength(I1); ++k)
{
for(unsigned wo = 0; wo < out_hkwn_thread_desc.GetLength(I2); ++wo)
{
for(unsigned n = 0; n < out_hkwn_thread_desc.GetLength(I3); ++n)
{
const unsigned b = out_hkwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
const unsigned ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const unsigned wo_thread = b_thread / NPerBlock;
const unsigned n_thread = b_thread - NPerBlock * wo_thread;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_hkwn_thread_desc.Get1dIndex(ho, k, wo, n)];
}
}
}
}
#endif
}
......@@ -259,16 +259,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
__syncthreads();
// load next data
#if 0
#if 1
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
#elif 0
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
#elif 1
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
......@@ -300,8 +293,6 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
}
#if 0
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
#elif 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
#endif
......