"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8ccc76ab3760cdb1ab60c7a344e16f118bb58adc"
Commit 120ab94a authored by Chao Liu's avatar Chao Liu
Browse files

update with new copy op

parent 07f16673
...@@ -391,7 +391,7 @@ int main() ...@@ -391,7 +391,7 @@ int main()
constexpr unsigned HPad = 0; constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0; constexpr unsigned WPad = 0;
#elif 1 #elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr unsigned N = 64; constexpr unsigned N = 64;
constexpr unsigned C = 256; constexpr unsigned C = 256;
...@@ -587,11 +587,11 @@ int main() ...@@ -587,11 +587,11 @@ int main()
device_implicit_gemm_convolution_1_nchw_kcsr device_implicit_gemm_convolution_1_nchw_kcsr
#elif 0 #elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 0 #elif 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn device_implicit_gemm_convolution_1_chwn_csrk_khwn
#elif 0 #elif 0
device_implicit_gemm_convolution_2_cnhw_srck_knhw device_implicit_gemm_convolution_2_cnhw_srck_knhw
#elif 1 #elif 0
device_implicit_gemm_convolution_2_cnhw_csrk_knhw device_implicit_gemm_convolution_2_cnhw_csrk_knhw
#endif #endif
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
...@@ -608,7 +608,7 @@ int main() ...@@ -608,7 +608,7 @@ int main()
nrepeat); nrepeat);
#endif #endif
#if 1 #if 0
if(S == 3 && R == 3) if(S == 3 && R == 3)
{ {
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
......
...@@ -87,7 +87,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -87,7 +87,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 8; constexpr unsigned BlockSize = 8;
#elif 1 #elif 0
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256 // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16; constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
...@@ -101,6 +101,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -101,6 +101,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned HoPerThread = 1; constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 0 #elif 0
// 3x3 58x58, NKC = 16,256,128 // 3x3 58x58, NKC = 16,256,128
...@@ -162,7 +168,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -162,7 +168,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 0 #elif 1
// for 1x1, 28x28 // for 1x1, 28x28
constexpr unsigned NPerBlock = 16; constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128; constexpr unsigned KPerBlock = 128;
...@@ -176,6 +182,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -176,6 +182,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned HoPerThread = 1; constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#endif #endif
...@@ -211,7 +223,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -211,7 +223,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
KPerThread, KPerThread,
CPerThread, CPerThread,
HoPerThread, HoPerThread,
WoPerThread> WoPerThread,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>
<<<grid_dim, block_dim>>>(in_chwn_desc, <<<grid_dim, block_dim>>>(in_chwn_desc,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()), static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
wei_csrk_desc, wei_csrk_desc,
......
...@@ -108,6 +108,9 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, ...@@ -108,6 +108,9 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
constexpr unsigned HoPerThread = 1; constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 0 #elif 0
// 3x3 58x58, NKC = 16,256,128 // 3x3 58x58, NKC = 16,256,128
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "ConstantTensorDescriptor.cuh" #include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh" #include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh" #include "blockwise_4d_tensor_op.cuh"
#include "blockwise_2d_tensor_op.cuh"
#include "threadwise_4d_tensor_op.cuh" #include "threadwise_4d_tensor_op.cuh"
#include "blockwise_gemm.cuh" #include "blockwise_gemm.cuh"
...@@ -21,7 +22,11 @@ template <unsigned GridSize, ...@@ -21,7 +22,11 @@ template <unsigned GridSize,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread, unsigned CPerThread,
unsigned HoPerThread, unsigned HoPerThread,
unsigned WoPerThread> unsigned WoPerThread,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
__global__ void __global__ void
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
Float* const __restrict__ p_in_global, Float* const __restrict__ p_in_global,
...@@ -80,12 +85,19 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, ...@@ -80,12 +85,19 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
const unsigned hi_block_data_begin = ho_block_data_begin; const unsigned hi_block_data_begin = ho_block_data_begin;
const unsigned wi_block_data_begin = wo_block_data_begin; const unsigned wi_block_data_begin = wo_block_data_begin;
// flattened (2d) tensor view of gridwise weight
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
// tensor view of blockwise input and weight in LDS // tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr auto in_chwn_block_desc = constexpr auto in_chwn_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}); make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
constexpr auto wei_csrk_block_desc = constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
make_ConstantTensorDescriptor(Sequence<CPerBlock, S, R, KPerBlock>{}); Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_hkwn_thread_desc = constexpr auto out_hkwn_thread_desc =
...@@ -112,13 +124,31 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, ...@@ -112,13 +124,31 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
decltype(in_chwn_block_desc), decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths())>{}; decltype(in_chwn_block_desc.GetLengths())>{};
// weight: format is [S,R,C,K] // blockwise wei copy
constexpr auto blockwise_wei_copy = // format is [CPerBlock*S*R,KPerBlock]
Blockwise4dTensorCopy1<BlockSize, #if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float, Float,
decltype(wei_csrk_global_desc), decltype(wei_ek_global_desc),
decltype(wei_csrk_block_desc), decltype(wei_ek_block_desc),
decltype(wei_csrk_block_desc.GetLengths())>{}; decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
...@@ -155,12 +185,17 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, ...@@ -155,12 +185,17 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
CPerThread, CPerThread,
true>{}; true>{};
// LDS // LDS: be careful of alignment
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace(); constexpr unsigned wei_block_size =
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float p_in_block[in_block_size]; __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[wei_block_size]; __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// register // register
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment