Commit 796f72e2 authored by Chao Liu's avatar Chao Liu
Browse files

load smaller weight tensor

parent 5b36aead
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "gridwise_convolution_wrapper.hip.hpp" #include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc> template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...@@ -179,32 +180,42 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -179,32 +180,42 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t WoPerThread = 1; constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 1
// for 7x7, 38x38 // for 7x7, 38x38
constexpr index_t NPerBlock = 8; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 1; constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4; constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4; constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4; constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1; constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1; constexpr index_t WoPerThread = 2;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32; constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 1 #elif 1
// for 3x3, 56x56 // for 3x3, 56x56
constexpr index_t NPerBlock = 32; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 4; constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2; constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2; constexpr index_t WoPerBlock = 2;
...@@ -216,19 +227,19 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -216,19 +227,19 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2; constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 1; constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4; constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4; constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8; constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4; constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 0
...@@ -264,7 +275,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -264,7 +275,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 0 #elif 1
// for 1x1, 14x14, Pascal // for 1x1, 14x14, Pascal
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -306,9 +317,11 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -306,9 +317,11 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 1 #if 0
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
#else #elif 1
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
#endif #endif
<GridSize, <GridSize,
......
...@@ -421,13 +421,13 @@ int main(int argc, char* argv[]) ...@@ -421,13 +421,13 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 3x3, 56x56 // 3x3, 56x56
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 64; constexpr index_t C = 64;
constexpr index_t HI = 56; constexpr index_t HI = 56;
constexpr index_t WI = 56; constexpr index_t WI = 56;
constexpr index_t K = 64; constexpr index_t K = 128;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -454,13 +454,13 @@ int main(int argc, char* argv[]) ...@@ -454,13 +454,13 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 7x7, 38x38 // 7x7, 38x38
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 38; constexpr index_t HI = 38;
constexpr index_t WI = 38; constexpr index_t WI = 38;
constexpr index_t K = 64; constexpr index_t K = 128;
constexpr index_t Y = 7; constexpr index_t Y = 7;
constexpr index_t X = 7; constexpr index_t X = 7;
...@@ -644,6 +644,9 @@ int main(int argc, char* argv[]) ...@@ -644,6 +644,9 @@ int main(int argc, char* argv[])
#if 0 #if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
...@@ -666,7 +669,7 @@ int main(int argc, char* argv[]) ...@@ -666,7 +669,7 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1 #elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1 #elif 0
device_implicit_gemm_convolution_2_chwn_cyxk_khwn device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif #endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
......
...@@ -95,15 +95,18 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -95,15 +95,18 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
// tensor view of blockwise input and weight in LDS // tensor view of blockwise input and weight in LDS
// be careful of alignment // be careful of alignment
constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead>{}); Number<max_align>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{}); Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<max_align>{});
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{}); Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<max_align>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor(
...@@ -147,7 +150,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -147,7 +150,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
constexpr auto c_kxwn_thread_mtx_desc = constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{}, Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I1)>{}); Number<out_khwn_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm = const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
...@@ -169,9 +172,6 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -169,9 +172,6 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
HoPerThread>{}; HoPerThread>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{}); constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = constexpr index_t wei_block_space =
......
...@@ -98,15 +98,18 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer ...@@ -98,15 +98,18 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
// tensor view of blockwise input and weight in LDS // tensor view of blockwise input and weight in LDS
// be careful of alignment // be careful of alignment
constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead>{}); Number<max_align>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{}); Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<max_align>{});
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{}); Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<max_align>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor(
...@@ -150,7 +153,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer ...@@ -150,7 +153,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
constexpr auto c_kxwn_thread_mtx_desc = constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{}, Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I1)>{}); Number<out_khwn_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm = const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
...@@ -172,9 +175,6 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer ...@@ -172,9 +175,6 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
HoPerThread>{}; HoPerThread>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{}); constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = constexpr index_t wei_block_space =
......
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup ");
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// 2d tensor view of gridwise weight
constexpr auto wei_ck_global_desc = make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
constexpr auto wei_ck_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
// tensor view of threadwise output in register
constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{};
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
const auto blockwise_wei_copy =
#if 0//debug
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ck_global_desc),
decltype(wei_ck_block_desc),
decltype(wei_ck_block_desc.GetLengths())>{};
#else
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ck_global_desc),
decltype(wei_ck_block_desc),
decltype(wei_ck_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_ck_block_desc.GetStride(I0)>{});
constexpr auto b_cxwn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{});
constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
0,
in_chwn_block_desc.GetStride(I1),
out_khwn_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_ck_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
#if 1
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
print_ConstantTensorDescriptor(wei_cyxk_global_desc, "wei_cyxk_global_desc");
print_ConstantTensorDescriptor(wei_ck_global_desc, "wei_ck_global_desc");
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
print_ConstantTensorDescriptor(wei_ck_block_desc, "wei_ck_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
const Float* p_in_global_block_offset =
p_in_global +
in_chwn_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0))
{
// input: global mem to LDS
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
// weight: global mem to LDS
blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_cyxk_global_desc.Get1dIndex(0, y, x, 0),
p_wei_block);
__syncthreads();
blockwise_batch_gemm.Run(p_wei_block,
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread);
__syncthreads();
}
}
}
// output: register to global mem,
#if 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
}
#elif 1
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / (N1 * N2),
N1,
N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
threadwise_10d_tensor_copy(
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
#endif
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment