Commit d2436a58 authored by Chao Liu

v1r3 nchw*cyxk=nkhw lds double buffer

parent 63cdc6d2
......@@ -95,7 +95,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 2;
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
......@@ -130,7 +130,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 2;
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
......@@ -200,7 +200,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
......
......@@ -3,6 +3,7 @@
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
......@@ -92,7 +93,42 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<1, 2, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
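// note: element-wise, InBlockReorderSrcSubLengths_NCHW * InBlockReorderSrcClusterLengths_NCHW
// = [NPerBlock, CPerBlock, HoPerBlock, WoPerBlock] = [1, 8, 4, 32], and the cluster lengths
// multiply to BlockSize = 256, so the 256 threads together cover the whole input block slice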
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 4;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 2;
......@@ -125,9 +161,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 1
// for 3x3, 34x34, v1r3, Vega 20, try
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 4;
......@@ -160,7 +196,77 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = void;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
#elif 1
// for 3x3, 28x28, v1r3, Pascal
constexpr index_t BlockSize = 128;
......@@ -206,39 +312,44 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw<
GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_cyxk_desc),
decltype(out_nkhw_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_W>{};
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#else
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
#endif
<GridSize,
BlockSize,
T,
decltype(in_nchw_desc),
decltype(wei_cyxk_desc),
decltype(out_nkhw_desc),
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_W>{};
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize),
......
......@@ -371,7 +371,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t ho = HoPerTile * htile + j;
for(int i = 0; i < WoPerTile; ++i)
{
std::size_t wo = WoPerTile * wtile + i;
std::size_t wo = WoPerTile * wtile + i;
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
}
}
......@@ -413,13 +413,13 @@ int main(int argc, char* argv[])
{
#if 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
......
......@@ -74,22 +74,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
......@@ -185,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// choose GEMM implementation here
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 1
#if 0
return blockwise_batch_gemm.Run(Xs...);
#elif 0
return blockwise_batch_gemm.Run_asm(Xs...);
......
......@@ -81,22 +81,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
......
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockReorderSrcSubLengths_NCHW,
class InBlockReorderSrcClusterLengths_NCHW,
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
index_t InBlockReorderDataPerRead_W,
index_t InBlockReorderDataPerWrite_N,
class WeiBlockCopyClusterLengths_CK, // not used
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_W>
struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerBlock % NPerThread == 0 &&
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t N = out_n_k_h_w_global_desc.GetLength(I0);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
// assert for LDS double buffer
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided by 2 * CPerBlock");
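// the main loop below advances through C in steps of 2 * CPerBlock (one CPerBlock slice per LDS
// buffer half) and the tail handles the final two slices, hence the 2 * CPerBlock requirement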
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
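// GetMultiIndex presumably unflattens the 1D block id over [NBlockWork, KBlockWork, HBlockWork,
// WBlockWork] in row-major order (WBlockWork fastest), giving each workgroup its (n, k, ho, wo)
// output tile without an explicit div/mod chain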
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// global tensor view
constexpr auto wei_c_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
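// 2d [C, K] view of the [C, Y, X, K] weight tensor at a fixed filter tap (y, x): consecutive C
// entries are Y * X * K apart and K is contiguous, so the same descriptor serves every (y, x)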
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockReorderDataPerWrite_N>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not meet");
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: reorder from [N, C, Hi, Wi] in global memory to [C, Hi, Wi, N] in LDS
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
const auto blockwise_in_copy_reorder =
BlockwiseNdTensorCopyReorder_v3<BlockSize,
Float,
decltype(in_n_c_h_w_global_desc),
decltype(in_c_h_w_n_block_desc),
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
decltype(map_chwn2nchw),
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N>{};
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
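// each thread presumably accumulates HoPerThread batches of a KPerThread x (WoPerThread * NPerThread)
// sub-tile of C_matrix in registers, matching c_k_wn_thread_mtx_desc and out_k_h_w_n_thread_desc below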
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// choose GEMM implementation here
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 0
return blockwise_batch_gemm.Run(Xs...);
#elif 0
return blockwise_batch_gemm.Run_asm(Xs...);
#else
return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
};
// LDS: be careful of alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
// LDS double buffer
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register
// C++ lambda doesn't capture array, use pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
const Float* p_in_global_block_offset =
p_in_global +
in_n_c_h_w_global_desc.Get1dIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);
// LDS double buffer: preload data into LDS
{
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double);
}
// LDS double buffer: main body
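// each pass alternates the two halves of the LDS buffers: while the blockwise GEMM consumes the
// half filled in the previous step, the copies prefetch the next CPerBlock slice into the other half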
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
c_block_data_begin += 2 * CPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
p_in_global_block_offset +=
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
p_wei_global_block_offset +=
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
}
}
// LDS double buffer: tail
{
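// the main loop stops two CPerBlock slices short of C: the even step below loads the final slice
// and runs GEMM on the one before it; the odd step runs GEMM on that final slice already in LDS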
Float p_in_register_clipboard[blockwise_in_copy_reorder
.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
// even iteration
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy_reorder.RunStoreRegisterClipboard(
p_in_register_clipboard, p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterClipboard(
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
}
}
// output: copy from registers to global memory
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
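// the C-matrix column packs (wo, n) with n fastest (col = wo * NPerBlock + n), because the B
// matrix flattens the block's [Wo, N] dimensions with N contiguous, hence the div/mod above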
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy does nothing but
// perfect forwarding. Using this
// trick makes this lambda a generic
// lambda, so it won't be compiled
// until instantiated
static_assert(
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
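// the 10d split factors below (N1/N2, K1/K2, W1/W2) presumably mirror how the blockwise GEMM
// scatters each thread's C sub-tile across the block, so that a plain threadwise copy with the
// reorder map writes the register tile to its strided destinations in the global NKHW output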
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
N1,
N2,
K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
// output is a 10d tensor
constexpr index_t N1 = NPerBlock;
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
});
}
};
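The kernel above follows the usual LDS double-buffering pipeline: preload slice 0, then in each step overlap the blockwise GEMM on the buffer filled previously with the prefetch of the next CPerBlock slice into the other buffer, and finish with a two-step tail. Below is a minimal host-side sketch of the same scheduling, simplified to one slice per step (the kernel unrolls two steps per iteration and handles the final two slices in the tail); load, gemm, and the sizes C/CPerBlock are illustrative placeholders, not code from this repository.

#include <cstdio>
#include <vector>

int main()
{
    constexpr int C = 64, CPerBlock = 8;   // reduction length and per-step slice (illustrative)
    std::vector<float> lds(2 * CPerBlock); // stand-in for the two LDS buffer halves
    float acc = 0;

    auto load = [&](int c, float* dst) {   // stand-in for the blockwise global->LDS copy
        for(int i = 0; i < CPerBlock; ++i)
            dst[i] = float(c + i);
    };
    auto gemm = [&](const float* src) {    // stand-in for the blockwise GEMM on one slice
        for(int i = 0; i < CPerBlock; ++i)
            acc += src[i];
    };

    load(0, lds.data()); // preload slice 0 into buffer half 0

    for(int c = CPerBlock; c < C; c += CPerBlock)
    {
        int step     = c / CPerBlock;
        float* next  = lds.data() + (step % 2) * CPerBlock;       // half being filled
        float* now   = lds.data() + ((step - 1) % 2) * CPerBlock; // half filled last step
        load(c, next); // prefetch the next slice while...
        gemm(now);     // ...computing on the previous one
    }
    gemm(lds.data() + ((C / CPerBlock - 1) % 2) * CPerBlock); // tail: last slice

    printf("acc = %f (expect %d)\n", acc, (C - 1) * C / 2);
    return 0;
}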
......@@ -193,7 +193,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
// choose GEMM implementation here
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 1
#if 0
return blockwise_batch_gemm.Run(Xs...);
#elif 0
return blockwise_batch_gemm.Run_asm(Xs...);
......@@ -340,39 +340,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
N1,
N2,
K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy does nothing but
// perfect forwarding. Using this
// trick makes this lambda a generic
// lambda, so it won't be compiled
// until instantiated
static_assert(
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
N1,
N2,
K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
......@@ -387,51 +388,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
}
#endif
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(
n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
})
.else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
// output is a 10d tensor
constexpr index_t N1 = NPerBlock;
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<N / N1,
N1,
K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2 * W3),
W1,
W2,
W3>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
// output is a 10d tensor
constexpr index_t N1 = NPerBlock;
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
......@@ -447,21 +437,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
}
#endif
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(
n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
});
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
});
}
};