Commit 2eeeb176 authored by Chao Liu

refactor

parent 08cbac98
@@ -18,6 +18,8 @@ template <index_t GridSize,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
+          class ConvStrides,
+          class ConvDilations,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,
@@ -64,10 +66,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
-        constexpr auto I4 = Number<4>{};
         constexpr auto I5 = Number<5>{};
-        constexpr auto I6 = Number<6>{};
-        constexpr auto I7 = Number<7>{};

         constexpr auto True = integral_constant<bool, true>{};
@@ -77,8 +76,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
         constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
-        constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
-        constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);

         constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
         constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
@@ -87,6 +84,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
         constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
         static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");

         constexpr index_t N0 = N / (N1 * N2);
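The two new template parameters are read with a compile-time subscript, so the strides and dilations are constants inside the kernel. Below is a minimal standalone sketch of why ConvStrides{}[0] works in a constexpr context; the real Sequence in this codebase is richer, so treat this as an illustration, not the library's implementation:

    #include <cstddef>

    using index_t = std::size_t;

    // Constexpr integer list: operator[] is usable in constant expressions,
    // which is what lets "constexpr index_t ConvStrideH = ConvStrides{}[0];"
    // compile inside the kernel.
    template <index_t... Is>
    struct Sequence
    {
        constexpr index_t operator[](index_t i) const
        {
            const index_t vals[] = {Is...};
            return vals[i];
        }
    };

    // e.g. a stride-2, non-dilated convolution could be instantiated with:
    using ConvStrides   = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    static_assert(ConvStrides{}[0] == 2, "ConvStrideH");
    static_assert(ConvDilations{}[1] == 1, "ConvDilationW");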
@@ -95,6 +98,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         constexpr index_t E = C * Y * X;

+        // sanity check for vectorized global memory load of the input tensor
+        static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
+                      "wrong! vectorized global load of input tensor requires ConvStrideW == 1");
+
+        static_assert(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0,
+                      "wrong! alignment requirement for vectorized global load of input tensor "
+                      "will be violated");
+
         // divide block work by [K, B]
         static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                       "wrong! cannot divide work evenly among blocks");
@@ -113,14 +124,16 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         // input tensor
         // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
-        constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{})
-                                                         .Slice(I3, Number<Wo>{})
+        constexpr auto in_n0_n1_n2_h_w_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
+                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
                 .Fold(I0, Number<N1>{}, Number<N2>{})
                 .Extract(Sequence<0, 1, 2, 4, 5>{});

         // batch descriptor for device memory
-        constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
-                                                  .Slice(I3, Number<X>{})
+        constexpr auto in_c_y_x_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                 .Extract(Sequence<1, 2, 3>{});

         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
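StridedSlice(dim, length, step) is the stride-aware replacement for Slice: it exposes `length` elements along `dim` while scaling that dimension's element stride by `step`, which is how output pixels walk the input at ConvStride and filter taps walk it at ConvDilation. A hypothetical one-dimensional illustration of the bookkeeping (not the library's descriptor code):

    struct Dim
    {
        int length; // number of elements visible along this dimension
        int stride; // elements skipped in memory per step along this dimension
    };

    // StridedSlice on one dimension: shrink the length, scale the stride,
    // so index h in the sliced view maps to offset h * step * old_stride.
    constexpr Dim strided_slice(Dim d, int new_length, int step)
    {
        return Dim{new_length, d.stride * step};
    }

    // W dimension of the input, sliced for a stride-2 convolution with Wo = 13:
    constexpr Dim w_full{28, 1};
    constexpr Dim w_out_view = strided_slice(w_full, 13, 2);
    static_assert(w_out_view.stride == 2, "output pixel w reads input pixel 2 * w");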
@@ -131,17 +144,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
             Sequence<3, 6, 7>{},
             Sequence<5>{});

-#if 0
-        if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
-        {
-            print_ConstantTensorDescriptor(in_n0_n1_n2_h_w_global_desc,
-                                           "in_n0_n1_n2_h_w_global_desc: ");
-            print_ConstantTensorDescriptor(in_c_y_x_global_desc, "in_c_y_x_global_desc: ");
-            print_ConstantMergedTensorDescriptor(in_e_n1_b_n2_global_merged_desc,
-                                                 "in_e_n1_b_n2_global_merged_desc: ");
-        }
-#endif
-
         // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
@@ -206,13 +208,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         // b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
-        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(
-            Number<EPerBlock>{}, Number<KPerBlock>{}, Number<wei_e_k_block_desc.GetStride(I0)>{});
+        constexpr auto a_e_k_block_mtx_desc =
+            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc);

-        constexpr auto b_e_n1bn2_block_mtx_desc =
-            make_ConstantMatrixDescriptor(Number<EPerBlock>{},
-                                          Number<N1 * BPerBlock * N2>{},
-                                          Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});
+        constexpr auto b_e_n1bn2_block_mtx_desc =
+            make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(
+                in_e_n1_b_n2_block_desc.Unfold(I1, I3));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
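Unfold(I1, I3) merges dimensions 1 through 3 of the [E, N1, B, N2] LDS descriptor into one dimension of length N1 * BPerBlock * N2, yielding the 2-D shape the matrix descriptor expects; this is only valid because those trailing dimensions are packed in LDS. A sketch of the length arithmetic under that packing assumption (hypothetical helper, not the library API):

    #include <array>
    #include <cstddef>

    // Multiply the lengths of dims [first, last] together to get the unfolded
    // dimension's length. Valid only when those dims are contiguous in memory.
    template <std::size_t N>
    constexpr std::size_t unfold_length(const std::array<std::size_t, N>& lengths,
                                        std::size_t first,
                                        std::size_t last)
    {
        std::size_t product = 1;
        for(std::size_t i = first; i <= last; ++i)
            product *= lengths[i];
        return product;
    }

    // [E, N1, B, N2] block tile with hypothetical EPerBlock = 8, N1 = 2,
    // BPerBlock = 16, N2 = 4:
    static_assert(unfold_length<4>({8, 2, 16, 4}, 1, 3) == 2 * 16 * 4,
                  "GEMM N dimension is N1 * BPerBlock * N2 = 128");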
@@ -242,15 +243,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
                                              GemmDataPerReadA,
                                              GemmDataPerReadB>{};

-        // choose GEMM implementation here
-        const auto run_blockwise_gemm = [&](auto... Xs) {
-#if 1
-            return blockwise_gemm.Run(Xs...);
-#else
-            return blockwise_gemm.Run_amd_asm(Xs...);
-#endif
-        };
-
         // LDS allocation for input and weight: be careful of alignment
         constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                 WeiBlockCopyDstDataPerWrite_K,
@@ -281,7 +273,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         __syncthreads();

-        run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);
+        blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

         __syncthreads();
@@ -293,7 +285,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
         {
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
-            constexpr index_t K0 = K / (K1 * K2);

             // define tensor descriptor for threadwise copy
             // output memory layout descriptor in register
......
@@ -59,7 +59,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     constexpr index_t B = (N * Ho * Wo) / (N1 * N2);

-#if 0
+#if 1
     // each thread holds 64 data
     constexpr index_t BlockSize = 256;
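For a sense of scale, B counts GEMM columns, and each column covers N1 * N2 output values. With hypothetical layer sizes (not taken from this commit) N = 64, Ho = Wo = 28 and the kernel's N1 = N2 = 2:

    // Hypothetical layer sizes, to make the B formula concrete.
    constexpr int N = 64, Ho = 28, Wo = 28; // batch and output spatial sizes
    constexpr int N1 = 2, N2 = 2;           // per-thread batch sub-factors

    static_assert(N % (N1 * N2) == 0, "N must split evenly over N1 * N2");
    constexpr int B = (N * Ho * Wo) / (N1 * N2); // 64 * 28 * 28 / 4
    static_assert(B == 12544, "the GEMM [K, B] problem has 12544 columns");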
@@ -94,7 +94,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
     constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-#elif 1
+#elif 0
     // each thread holds 32 data
     constexpr index_t BlockSize = 256;
......