Unverified commit 1a66e35b, authored by Chao Liu, committed by GitHub

MIopen integration (#13)

* update for miopen integration: cosmetic refactor
parent 3406a114
@@ -114,10 +114,10 @@ struct GridwiseCol2Im_eb_nchw
  1,
  BlockCopyDataPerAccess_B,
  BlockCopyDataPerAccess_B,
- AddressSpace::vgpr,
- AddressSpace::vgpr,
- AddressSpace::global,
- InMemoryDataOperation::atomic_add>(
+ AddressSpace::Vgpr,
+ AddressSpace::Vgpr,
+ AddressSpace::Global,
+ InMemoryDataOperation::AtomicAdd>(
  {e_block_data_on_global, b_block_data_on_global},
  {e_block_data_on_global, b_block_data_on_global});
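The copy above ends in InMemoryDataOperation::AtomicAdd because col2im scatters overlapping filter windows back onto the same image elements: whenever the stride is smaller than the (dilated) filter extent, several column entries target one pixel and their contributions must be summed. A minimal host-side sketch of that accumulation, with assumed 1-D sizes and hypothetical names (not the blockwise copy used here):

#include <cstdio>
#include <vector>

// Hypothetical 1-D col2im sketch: every "column" element is added back into
// the image position it was gathered from. When stride < X the windows
// overlap, several column entries hit the same image element, and a parallel
// implementation has to combine them with atomic adds (hence AtomicAdd above).
int main()
{
    constexpr int X = 3, Wo = 4, stride = 1, Wi = stride * (Wo - 1) + X;
    std::vector<float> col(X * Wo, 1.0f), img(Wi, 0.0f);

    for(int x = 0; x < X; ++x)
        for(int wo = 0; wo < Wo; ++wo)
            img[wo * stride + x] += col[x * Wo + wo]; // overlapping writes accumulate

    for(float v : img)
        std::printf("%.0f ", v); // interior pixels receive up to X contributions
    std::printf("\n");
}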
...
@@ -25,15 +25,15 @@ template <index_t GridSize,
  index_t GemmMPerBlock,
  index_t GemmNPerBlock,
  index_t GemmKPerBlock,
- index_t GemmMPerThreadSubC,
- index_t GemmNPerThreadSubC,
+ index_t GemmMPerThread,
+ index_t GemmNPerThread,
+ index_t GemmKPerThread,
  index_t GemmMLevel0Cluster,
  index_t GemmNLevel0Cluster,
  index_t GemmMLevel1Cluster,
  index_t GemmNLevel1Cluster,
- index_t GemmKPerThreadLoop,
- index_t GemmThreadGemmDataPerReadM,
- index_t GemmThreadGemmDataPerReadN,
+ index_t ThreadGemmAThreadCopySrcDataPerRead_GemmM,
+ index_t ThreadGemmAThreadCopySrcDataPerRead_GemmN,
  typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
  typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
  index_t GemmABlockCopySrcDataPerRead_GemmN,
@@ -75,25 +75,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
  constexpr index_t ConvDilationH = ConvDilations{}[0];
  constexpr index_t ConvDilationW = ConvDilations{}[1];
- // sanity-check for vectorized memory load
- // TODO: this logic may not be correct for bwd-data
- static_assert(
- (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
- (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
- "wrong! aligment requirement for vectorized global load of input tensor will "
- "be violated");
+ //\todo static_assert for global vector load/store
+ // static_assert();
+ // weight tensor
+ constexpr auto wei_gemmk_gemmm_global_desc =
+ unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
  // output tensor
- constexpr auto out_k_b_global_desc =
+ constexpr auto out_gemmk_gemmn_global_desc =
  transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
  make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
  make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
- // weight tensor
- constexpr auto wei_k_e_global_desc =
- unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
  // input tensor
  constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
  in_n_c_hi_wi_global_desc,
@@ -116,38 +111,42 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
- constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
+ constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
  in_n_c_y_ho_x_wo_global_desc,
  make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
  make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
  // GEMM
- constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
- ? InMemoryDataOperation::none
- : InMemoryDataOperation::atomic_add;
+ // \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
+ // atomic, find out all of them
+ constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and
+ (ConvStrideW >= ConvDilationW * (X - 1) + 1);
+ constexpr auto in_memory_op =
+ not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd;
  constexpr auto gridwise_gemm =
  GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
  BlockSize,
  Float,
  AccFloat,
- decltype(wei_k_e_global_desc),
- decltype(out_k_b_global_desc),
- decltype(in_e_b_global_desc),
+ decltype(wei_gemmk_gemmm_global_desc),
+ decltype(out_gemmk_gemmn_global_desc),
+ decltype(in_gemmm_gemmn_global_desc),
  in_memory_op,
  GemmMPerBlock,
  GemmNPerBlock,
  GemmKPerBlock,
- GemmMPerThreadSubC,
- GemmNPerThreadSubC,
+ GemmMPerThread,
+ GemmNPerThread,
+ GemmKPerThread,
  GemmMLevel0Cluster,
  GemmNLevel0Cluster,
  GemmMLevel1Cluster,
  GemmNLevel1Cluster,
- GemmKPerThreadLoop,
- GemmThreadGemmDataPerReadM,
- GemmThreadGemmDataPerReadN,
+ ThreadGemmAThreadCopySrcDataPerRead_GemmM,
+ ThreadGemmAThreadCopySrcDataPerRead_GemmN,
  GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
  GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
  Sequence<0, 1>,
...
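The backward-data kernel above now picks between Set and AtomicAdd through not_need_atomic: when the stride is at least the dilated filter span ConvDilation * (Y - 1) + 1, neighbouring filter windows read disjoint input pixels, so the input-gradient writes never collide and no atomic accumulation is needed. A standalone restatement of that test (my paraphrase under assumed parameters, not the library's API):

#include <cstdio>

// Backward-data only needs atomic adds when neighbouring filter windows can
// touch the same input pixel, i.e. when the dilated filter span is wider
// than the stride along that dimension.
constexpr bool needs_atomic(int stride, int dilation, int filter)
{
    return stride < dilation * (filter - 1) + 1;
}

int main()
{
    std::printf("%d %d %d\n",
                needs_atomic(1, 1, 1),  // 0 -> InMemoryDataOperation::Set (1x1 filter)
                needs_atomic(1, 1, 3),  // 1 -> InMemoryDataOperation::AtomicAdd (overlap)
                needs_atomic(3, 1, 3)); // 0 -> InMemoryDataOperation::Set (stride == span)
}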
@@ -147,10 +147,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
  2,
  OutBlockCopySrcDataPerRead_B,
  OutBlockCopyDstDataPerWrite_N0,
- AddressSpace::global,
- AddressSpace::vgpr,
- AddressSpace::lds,
- InMemoryDataOperation::none>(
+ AddressSpace::Global,
+ AddressSpace::Vgpr,
+ AddressSpace::Lds,
+ InMemoryDataOperation::Set>(
  {0, b_block_data_on_global, 0}, {0, 0, 0});
  // weight tensor
@@ -187,10 +187,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
  2,
  WeiBlockCopySrcDataPerRead_E,
  WeiBlockCopyDstDataPerWrite_C0,
- AddressSpace::global,
- AddressSpace::vgpr,
- AddressSpace::lds,
- InMemoryDataOperation::none>(
+ AddressSpace::Global,
+ AddressSpace::Vgpr,
+ AddressSpace::Lds,
+ InMemoryDataOperation::Set>(
  {0, e_block_data_on_global, 0}, {0, 0, 0});
  // GEMM definition
@@ -356,10 +356,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
  #if 1 // debug
  // input: register to global memory, atomic add
  constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
- ? InMemoryDataOperation::none
- : InMemoryDataOperation::atomic_add;
+ ? InMemoryDataOperation::Set
+ : InMemoryDataOperation::AtomicAdd;
  #else
- constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
+ constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd;
  #endif
  constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
@@ -432,8 +432,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
  4,
  1,
  InThreadCopyDstDataPerWrite_B,
- AddressSpace::vgpr,
- AddressSpace::global,
+ AddressSpace::Vgpr,
+ AddressSpace::Global,
  in_memory_op>({0, 0, 0, 0, 0, 0},
  {e_thread_data_on_global / E1,
  e_thread_data_on_global % E1,
...
@@ -8,9 +8,9 @@
  namespace ck {
- // GemmM = C * Ytilda * Xtilda;
- // GemmN = N * HtildaNonZero * WtildaNonZero;
- // GemmK = K * Ydot * Xdot;
+ // GemmM = C * YTilda * XTilda;
+ // GemmN = N * HTildaSlice * WTildaSlice;
+ // GemmK = K * YDot * XDot;
  template <index_t GridSize,
  index_t BlockSize,
  typename Float,
@@ -25,13 +25,13 @@ template <index_t GridSize,
  index_t GemmMPerBlock,
  index_t GemmNPerBlock,
  index_t GemmKPerBlock,
- index_t GemmMPerThreadSubC,
- index_t GemmNPerThreadSubC,
+ index_t GemmMPerThread,
+ index_t GemmNPerThread,
+ index_t GemmKPerThread,
  index_t GemmMLevel0Cluster,
  index_t GemmNLevel0Cluster,
  index_t GemmMLevel1Cluster,
  index_t GemmNLevel1Cluster,
- index_t GemmKPerThreadLoop,
  index_t GemmThreadGemmDataPerReadM,
  index_t GemmThreadGemmDataPerReadN,
  typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -81,32 +81,32 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
  "be violated");
  #endif
- constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
- constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
- constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
- constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
- constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
- constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
- constexpr index_t Htilda =
+ constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+ constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
+ constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
+ constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
+ constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
+ constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
+ constexpr index_t HTilda =
  Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
- constexpr index_t Wtilda =
+ constexpr index_t WTilda =
  Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
- constexpr index_t HtildaLeft = math::integer_divide_floor(
- math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
- constexpr index_t WtildaLeft = math::integer_divide_floor(
- math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
- constexpr index_t HtildaRight = math::min(
- Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
- constexpr index_t WtildaRight = math::min(
- Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
- constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
- constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
+ constexpr index_t HTildaLeft = math::integer_divide_floor(
+ math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
+ constexpr index_t WTildaLeft = math::integer_divide_floor(
+ math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
+ constexpr index_t HTildaRight = math::min(
+ HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
+ constexpr index_t WTildaRight = math::min(
+ WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
+ constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
+ constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
  // weight tensor
  constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
@@ -114,17 +114,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
  make_tuple(PassThrough<K>{},
  PassThrough<C>{},
  Embed<Y,
- Sequence<Ydot, Ytilda>,
- Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
+ Sequence<YDot, YTilda>,
+ Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>>{},
  Embed<X,
- Sequence<Xdot, Xtilda>,
- Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
+ Sequence<XDot, XTilda>,
+ Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
  constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
  wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
- make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<C, Ytilda, Xtilda>>{}),
+ make_tuple(Merge<Sequence<K, YDot, XDot>>{}, Merge<Sequence<C, YTilda, XTilda>>{}),
  make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
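A small worked example of the index arithmetic above, with assumed convolution parameters (stride 2, dilation 1, 3x3 filter, Ho = 28) rather than the kernel's template arguments. It shows what the GcdStrideDilation / YTilda / YDot / HTilda constants evaluate to, and how the Embed transform splits the filter row y into (ydot, ytilda) via y = (ConvStrideH / Gcd) * ydot + ytilda:

#include <cstdio>
#include <numeric>

int main()
{
    constexpr int ConvStrideH = 2, ConvDilationH = 1, Y = 3, Ho = 28;

    const int GcdStrideDilationH = std::gcd(ConvStrideH, ConvDilationH);              // 1
    const int YTilda = ConvStrideH / GcdStrideDilationH;                              // 2
    const int YDot   = (Y + YTilda - 1) / YTilda;                                     // ceil(3/2) = 2
    const int HTilda = Ho + (ConvDilationH * (Y - 1) + ConvStrideH - 1) / ConvStrideH; // 28 + 1 = 29

    // Embed<Y, Sequence<YDot, YTilda>, Sequence<ConvStrideH/Gcd, 1, 0>> maps
    // (ydot, ytilda) -> y = (ConvStrideH/Gcd)*ydot + ytilda.  Combinations that
    // land past y = Y - 1 do not correspond to a real filter tap; they are the
    // ones trimmed away later (YDotSlice) or guarded by out-of-bound checks.
    for(int ydot = 0; ydot < YDot; ++ydot)
        for(int ytilda = 0; ytilda < YTilda; ++ytilda)
            std::printf("ydot=%d ytilda=%d -> y=%d\n",
                        ydot, ytilda, (ConvStrideH / GcdStrideDilationH) * ydot + ytilda);

    std::printf("YTilda=%d YDot=%d HTilda=%d\n", YTilda, YDot, HTilda);
}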
@@ -134,33 +134,33 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
  make_tuple(PassThrough<N>{},
  PassThrough<K>{},
  Embed<Ho,
- Sequence<Ydot, Htilda>,
- Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
+ Sequence<YDot, HTilda>,
+ Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>>{},
  Embed<Wo,
- Sequence<Xdot, Wtilda>,
- Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
+ Sequence<XDot, WTilda>,
+ Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
- constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
+ constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
  transform_tensor_descriptor(
  out_n_k_ydot_htilda_xdot_wtilda_global_desc,
  make_tuple(PassThrough<N>{},
  PassThrough<K>{},
- PassThrough<Ytilda>{},
- PassThrough<Xtilda>{},
- Slice<Sequence<Htilda, Wtilda>,
- Sequence<HtildaLeft, WtildaLeft>,
- Sequence<HtildaRight, WtildaRight>>{}),
+ PassThrough<YTilda>{},
+ PassThrough<XTilda>{},
+ Slice<Sequence<HTilda, WTilda>,
+ Sequence<HTildaLeft, WTildaLeft>,
+ Sequence<HTildaRight, WTildaRight>>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
  constexpr auto out_gemmk_gemmn_global_desc =
- transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
- make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
- Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+ transform_tensor_descriptor(out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
+ make_tuple(Merge<Sequence<K, YDot, XDot>>{},
+ Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
  make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -188,35 +188,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
  make_tuple(PassThrough<N>{},
  PassThrough<C>{},
  Embed<Hip,
- Sequence<Ytilda, Htilda>,
+ Sequence<YTilda, HTilda>,
  Sequence<ConvDilationH, ConvStrideH, 0>,
  in_skip_all_out_of_bound_check>{},
  Embed<Wip,
- Sequence<Xtilda, Wtilda>,
+ Sequence<XTilda, WTilda>,
  Sequence<ConvDilationW, ConvStrideW, 0>,
  in_skip_all_out_of_bound_check>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
- constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
+ constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
  transform_tensor_descriptor(
  in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
  make_tuple(PassThrough<N>{},
  PassThrough<C>{},
- PassThrough<Ytilda>{},
- PassThrough<Xtilda>{},
- Slice<Sequence<Htilda, Wtilda>,
- Sequence<HtildaLeft, WtildaLeft>,
- Sequence<HtildaRight, WtildaRight>>{}),
+ PassThrough<YTilda>{},
+ PassThrough<XTilda>{},
+ Slice<Sequence<HTilda, WTilda>,
+ Sequence<HTildaLeft, WTildaLeft>,
+ Sequence<HTildaRight, WTildaRight>>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
  constexpr auto in_gemmm_gemmn_global_desc =
- transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
- make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
- Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+ transform_tensor_descriptor(in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
+ make_tuple(Merge<Sequence<C, YTilda, XTilda>>{},
+ Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
  make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -229,17 +229,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
  decltype(wei_gemmk_gemmm_global_desc),
  decltype(out_gemmk_gemmn_global_desc),
  decltype(in_gemmm_gemmn_global_desc),
- InMemoryDataOperation::none,
+ InMemoryDataOperation::Set,
  GemmMPerBlock,
  GemmNPerBlock,
  GemmKPerBlock,
- GemmMPerThreadSubC,
- GemmNPerThreadSubC,
+ GemmMPerThread,
+ GemmNPerThread,
+ GemmKPerThread,
  GemmMLevel0Cluster,
  GemmNLevel0Cluster,
  GemmMLevel1Cluster,
  GemmNLevel1Cluster,
- GemmKPerThreadLoop,
  GemmThreadGemmDataPerReadM,
  GemmThreadGemmDataPerReadN,
  GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
...
@@ -8,10 +8,10 @@
  namespace ck {
- // Ytilda*Xtilda number of GEMMs
- // GemmM = C;
- // GemmN = N * HtildaNonZero * WtildaNonZero;
- // GemmK = K * YdotNonZero * XdotNonZero;
+ // Number of GEMMs: YTilda * XTilda
+ // GemmM = C
+ // GemmN = N * HTildaSlice * WTildaSlice
+ // GemmK = K * YDotSlice * XDotSlice
  template <index_t GridSize,
  index_t BlockSize,
  typename Float,
@@ -26,13 +26,13 @@ template <index_t GridSize,
  index_t GemmMPerBlock,
  index_t GemmNPerBlock,
  index_t GemmKPerBlock,
- index_t GemmMPerThreadSubC,
- index_t GemmNPerThreadSubC,
+ index_t GemmMPerThread,
+ index_t GemmNPerThread,
+ index_t GemmKPerThread,
  index_t GemmMLevel0Cluster,
  index_t GemmNLevel0Cluster,
  index_t GemmMLevel1Cluster,
  index_t GemmNLevel1Cluster,
- index_t GemmKPerThreadLoop,
  index_t GemmThreadGemmDataPerReadM,
  index_t GemmThreadGemmDataPerReadN,
  typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -110,32 +110,32 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  "be violated");
  #endif
- constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
- constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
- constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
- constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
- constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
- constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
- constexpr index_t Htilda =
+ constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+ constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
+ constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
+ constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
+ constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
+ constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
+ constexpr index_t HTilda =
  Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
- constexpr index_t Wtilda =
+ constexpr index_t WTilda =
  Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
- constexpr index_t HtildaLeft = math::integer_divide_floor(
- math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
- constexpr index_t WtildaLeft = math::integer_divide_floor(
- math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
- constexpr index_t HtildaRight = math::min(
- Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
- constexpr index_t WtildaRight = math::min(
- Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
- constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
- constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
+ constexpr index_t HTildaLeft = math::integer_divide_floor(
+ math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
+ constexpr index_t WTildaLeft = math::integer_divide_floor(
+ math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
+ constexpr index_t HTildaRight = math::min(
+ HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
+ constexpr index_t WTildaRight = math::min(
+ WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
+ constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
+ constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
  constexpr bool wei_skip_all_out_of_bound_check = true;
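Unlike v2r1, this v3r1 kernel launches one GEMM per (iYTilda, iXTilda) pair, YTilda * XTilda GEMMs in total, and each GEMM only consumes the filter taps whose row/column maps onto that residue; the YDotSlice / XDotSlice expressions further down trim the last, partially filled group. A worked example of that trimming under assumed parameters (stride 2, dilation 1, 3x3 filter, so YTilda = 2 and YDot = 2):

#include <cstdio>

int main()
{
    constexpr int Y = 3, YTilda = 2, YDot = (Y + YTilda - 1) / YTilda; // 2

    for(int iYTilda = 0; iYTilda < YTilda; ++iYTilda)
    {
        // same expression as "constexpr index_t YDotSlice = ..." in the loop below
        const int YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
        std::printf("iYTilda=%d -> YDotSlice=%d\n", iYTilda, YDotSlice);
        // iYTilda=0 covers taps y = 0, 2 (two taps); iYTilda=1 covers only y = 1
    }
}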
@@ -145,12 +145,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  make_tuple(PassThrough<K>{},
  PassThrough<C>{},
  Embed<Y,
- Sequence<Ydot, Ytilda>,
- Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
+ Sequence<YDot, YTilda>,
+ Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
  wei_skip_all_out_of_bound_check>{},
  Embed<X,
- Sequence<Xdot, Xtilda>,
- Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
+ Sequence<XDot, XTilda>,
+ Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
  wei_skip_all_out_of_bound_check>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -167,26 +167,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  make_tuple(PassThrough<N>{},
  PassThrough<K>{},
  Embed<Ho,
- Sequence<Ydot, Htilda>,
- Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
+ Sequence<YDot, HTilda>,
+ Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
  out_skip_all_out_of_bound_check>{},
  Embed<Wo,
- Sequence<Xdot, Wtilda>,
- Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
+ Sequence<XDot, WTilda>,
+ Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
  out_skip_all_out_of_bound_check>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
- constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
+ constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
  transform_tensor_descriptor(
  out_n_k_ydot_htilda_xdot_wtilda_global_desc,
  make_tuple(PassThrough<N>{},
  PassThrough<K>{},
- PassThrough<Ytilda>{},
- PassThrough<Xtilda>{},
- Slice<Sequence<Htilda, Wtilda>,
- Sequence<HtildaLeft, WtildaLeft>,
- Sequence<HtildaRight, WtildaRight>>{}),
+ PassThrough<YTilda>{},
+ PassThrough<XTilda>{},
+ Slice<Sequence<HTilda, WTilda>,
+ Sequence<HTildaLeft, WTildaLeft>,
+ Sequence<HTildaRight, WTildaRight>>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
  make_tuple(
@@ -216,26 +216,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  make_tuple(PassThrough<N>{},
  PassThrough<C>{},
  Embed<Hip,
- Sequence<Ytilda, Htilda>,
+ Sequence<YTilda, HTilda>,
  Sequence<ConvDilationH, ConvStrideH, 0>,
  in_skip_all_out_of_bound_check>{},
  Embed<Wip,
- Sequence<Xtilda, Wtilda>,
+ Sequence<XTilda, WTilda>,
  Sequence<ConvDilationW, ConvStrideW, 0>,
  in_skip_all_out_of_bound_check>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
- constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
+ constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
  transform_tensor_descriptor(
  in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
  make_tuple(PassThrough<N>{},
  PassThrough<C>{},
- PassThrough<Ytilda>{},
- PassThrough<Xtilda>{},
- Slice<Sequence<Htilda, Wtilda>,
- Sequence<HtildaLeft, WtildaLeft>,
- Sequence<HtildaRight, WtildaRight>>{}),
+ PassThrough<YTilda>{},
+ PassThrough<XTilda>{},
+ Slice<Sequence<HTilda, WTilda>,
+ Sequence<HTildaLeft, WTildaLeft>,
+ Sequence<HTildaRight, WTildaRight>>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
  make_tuple(
@@ -246,54 +246,49 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  __shared__ Float p_shared_block[shared_block_size];
- #if 1 // debug
- static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
- static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
- #else
- static_for<0, 1, 1>{}([&](auto ytilda_) {
- static_for<0, 1, 1>{}([&](auto xtilda_) {
- #endif
- constexpr index_t ytilda = decltype(ytilda_){};
- constexpr index_t xtilda = decltype(xtilda_){};
- constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
- constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
+ static_for<0, YTilda, 1>{}([&](auto iYTilda_) {
+ static_for<0, XTilda, 1>{}([&](auto iXTilda_) {
+ constexpr index_t iYTilda = decltype(iYTilda_){};
+ constexpr index_t iXTilda = decltype(iXTilda_){};
+ constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
+ constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
  // A matrix
- constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
+ constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
  transform_tensor_descriptor(
  wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
  make_tuple(PassThrough<K>{},
  PassThrough<C>{},
- Slice<Sequence<Ydot, Xdot>,
+ Slice<Sequence<YDot, XDot>,
  Sequence<0, 0>,
- Sequence<YdotNonZero, XdotNonZero>>{},
- Slice<Sequence<Ytilda, Xtilda>,
- Sequence<ytilda, xtilda>,
- Sequence<ytilda + 1, xtilda + 1>>{}),
+ Sequence<YDotSlice, XDotSlice>>{},
+ Slice<Sequence<YTilda, XTilda>,
+ Sequence<iYTilda, iXTilda>,
+ Sequence<iYTilda + 1, iXTilda + 1>>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
  make_tuple(
  Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
  constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
- wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
- make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
+ wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
+ make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
  Merge<Sequence<C, 1, 1>>{}),
  make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
  // B matrix
- constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
+ constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
  transform_tensor_descriptor(
- out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
+ out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
  make_tuple(PassThrough<N>{},
  PassThrough<K>{},
- PassThrough<HtildaTrim>{},
- PassThrough<WtildaTrim>{},
- Slice<Sequence<Ydot, Xdot>,
+ PassThrough<HTildaSlice>{},
+ PassThrough<WTildaSlice>{},
+ Slice<Sequence<YDot, XDot>,
  Sequence<0, 0>,
- Sequence<YdotNonZero, XdotNonZero>>{}),
+ Sequence<YDotSlice, XDotSlice>>{}),
  make_tuple(Sequence<0>{},
  Sequence<1>{},
  Sequence<3>{},
@@ -306,23 +301,23 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  Sequence<2, 4>{}));
  constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
- out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
- make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
- Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+ out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
+ make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
+ Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
  make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
  // C matrix
- constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
+ constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
  transform_tensor_descriptor(
- in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
+ in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
  make_tuple(PassThrough<N>{},
  PassThrough<C>{},
- PassThrough<HtildaTrim>{},
- PassThrough<WtildaTrim>{},
- Slice<Sequence<Ytilda, Xtilda>,
- Sequence<ytilda, xtilda>,
- Sequence<ytilda + 1, xtilda + 1>>{}),
+ PassThrough<HTildaSlice>{},
+ PassThrough<WTildaSlice>{},
+ Slice<Sequence<YTilda, XTilda>,
+ Sequence<iYTilda, iXTilda>,
+ Sequence<iYTilda + 1, iXTilda + 1>>{}),
  make_tuple(Sequence<0>{},
  Sequence<1>{},
  Sequence<3>{},
@@ -335,9 +330,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  Sequence<2, 4>{}));
  constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
- in_n_c_1_htildatrim_1_wtildatrim_global_desc,
+ in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
  make_tuple(Merge<Sequence<C, 1, 1>>{},
- Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
+ Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
  make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
  make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -349,17 +344,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
  decltype(wei_gemmk_gemmm_global_desc),
  decltype(out_gemmk_gemmn_global_desc),
  decltype(in_gemmm_gemmn_global_desc),
- InMemoryDataOperation::none,
+ InMemoryDataOperation::Set,
  GemmMPerBlock,
  GemmNPerBlock,
  GemmKPerBlock,
- GemmMPerThreadSubC,
- GemmNPerThreadSubC,
+ GemmMPerThread,
+ GemmNPerThread,
+ GemmKPerThread,
  GemmMLevel0Cluster,
  GemmNLevel0Cluster,
  GemmMLevel1Cluster,
  GemmNLevel1Cluster,
- GemmKPerThreadLoop,
  GemmThreadGemmDataPerReadM,
  GemmThreadGemmDataPerReadN,
  GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
...
@@ -229,10 +229,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
  3,
  InBlockCopySrcDataPerRead_B,
  InBlockCopyDstDataPerWrite_N2,
- AddressSpace::global,
- AddressSpace::vgpr,
- AddressSpace::lds,
- InMemoryDataOperation::none>(
+ AddressSpace::Global,
+ AddressSpace::Vgpr,
+ AddressSpace::Lds,
+ InMemoryDataOperation::Set>(
  {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
  // weight tensor
@@ -269,10 +269,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
  1,
  WeiBlockCopySrcDataPerRead_E,
  WeiBlockCopyDstDataPerWrite_K,
- AddressSpace::global,
- AddressSpace::vgpr,
- AddressSpace::lds,
- InMemoryDataOperation::none>(
+ AddressSpace::Global,
+ AddressSpace::Vgpr,
+ AddressSpace::Lds,
+ InMemoryDataOperation::Set>(
  {0, k_block_data_on_global}, {0, 0});
  // GEMM definition
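The two blockwise copies above name three address spaces (Global source, Vgpr buffer, Lds destination) with a Set operation: each thread first loads its slice of the input/weight tile from global memory into registers and then stores it to LDS, which is exactly the split that the RunLoadThreadBuffer / RunStoreThreadBuffer pair exposes in the double-buffered loop further down. A schematic of that two-stage copy, with hypothetical helper names rather than the library's classes:

#include <vector>

// Schematic of the two-stage blockwise copy (hypothetical helper): splitting
// the copy into a register-load half and an LDS-store half is what lets the
// load overlap with the GEMM running on the other LDS buffer.
template <typename T, int SliceSize>
struct TwoStageCopySketch
{
    T thread_buf[SliceSize]; // stands in for the VGPR staging buffer

    void load(const T* p_global, int src_offset) // AddressSpace::Global -> Vgpr
    {
        for(int i = 0; i < SliceSize; ++i)
            thread_buf[i] = p_global[src_offset + i];
    }

    void store(T* p_lds, int dst_offset) const // Vgpr -> AddressSpace::Lds
    {
        for(int i = 0; i < SliceSize; ++i)
            p_lds[dst_offset + i] = thread_buf[i];
    }
};

int main()
{
    std::vector<float> global(64, 1.0f), lds(64, 0.0f);
    TwoStageCopySketch<float, 8> copy;
    copy.load(global.data(), 0);  // plays the role of RunLoadThreadBuffer
    copy.store(lds.data(), 0);    // plays the role of RunStoreThreadBuffer
}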
@@ -344,6 +344,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
  blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
  }
+ constexpr auto in_block_slice_copy_steps = Sequence<EPerBlock, 0, 0, 0>{};
+ constexpr auto wei_block_slice_copy_steps = Sequence<EPerBlock, 0>{};
  // LDS double buffer: main body
  for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
  e_block_data_begin += 2 * EPerBlock)
@@ -366,8 +369,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
  Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
  Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
- blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
- blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+ blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
+ blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);
  __syncthreads();
@@ -393,8 +396,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
  Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
  Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
- blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
- blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+ blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
+ blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);
  __syncthreads();
@@ -482,14 +485,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
  3,
  1,
  1,
- AddressSpace::vgpr,
- AddressSpace::global,
- InMemoryDataOperation::none>({0, 0, 0, 0, 0},
+ AddressSpace::Vgpr,
+ AddressSpace::Global,
+ InMemoryDataOperation::Set>({0, 0, 0, 0, 0},
  {k_thread_data_on_global / K1,
  k_thread_data_on_global % K1,
  0,
  b_thread_data_on_global,
  0})
  .Run(p_out_thread, p_out_global);
  }
  }
...
@@ -94,9 +94,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
  constexpr auto True = integral_constant<bool, true>{};
  constexpr auto generic_address_space =
- integral_constant<AddressSpace, AddressSpace::generic>{};
+ integral_constant<AddressSpace, AddressSpace::Generic>{};
  constexpr auto global_address_space =
- integral_constant<AddressSpace, AddressSpace::global>{};
+ integral_constant<AddressSpace, AddressSpace::Global>{};
  static_assert(ConvDirection == ConvolutionDirection::Forward ||
  ConvDirection == ConvolutionDirection::BackwardWeight,
@@ -141,13 +141,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
  constexpr index_t E = C * Y * X;
  // sanity-check for vectorized memory load
- static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
- (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
- "wrong! aligment requirement for vectorized global load of input tensor will "
- "be violated");
+ static_assert(
+ (Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
+ (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
+ "wrong! alignment requirement for vectorized global load of input tensor will "
+ "be violated");
  // divide block work by [K, B]
- static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
+ static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
  "wrong! cannot divide work evenly among block");
  constexpr index_t KBlockWork = K / KPerBlock;
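The sanity check above allows vector loads of width InBlockCopySrcDataPerRead_B along the B (= N * Ho * Wo) dimension only when consecutive b indices hit consecutive, equally aligned input addresses: the stride must be 1 (or Wo == 1, or the width 1), and the dilation must keep every filter column on the same alignment, hence ConvDilationW % width == 0 (or X == 1). A compile-time restatement of that rule with assumed example parameters, where vector_width stands in for InBlockCopySrcDataPerRead_B:

constexpr bool vector_load_ok(int Wo, int ConvStrideW, int X, int ConvDilationW, int vector_width)
{
    return (Wo == 1 || ConvStrideW == 1 || vector_width == 1) &&
           (X == 1 || ConvDilationW % vector_width == 0);
}

static_assert(vector_load_ok(14, 1, 3, 1, 1), "scalar loads are always aligned");
static_assert(vector_load_ok(14, 1, 3, 4, 4), "dilation a multiple of the vector width keeps columns aligned");
static_assert(!vector_load_ok(14, 2, 3, 1, 4), "strided input breaks 4-wide vector loads");

int main() {}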
@@ -357,37 +358,49 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
  // LDS double buffer: tail
  {
- // even iteration
- Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
- Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
- blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
- blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
- __syncthreads();
- // LDS doubel buffer: load next data from device mem
- blockwise_in_copy.RunLoadThreadBuffer(
- p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
- blockwise_wei_copy.RunLoadThreadBuffer(
- p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
- // LDS double buffer: GEMM on current data
- blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
- // LDS double buffer: store next data to LDS
- blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
- p_in_block_double + in_block_space);
- blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
- p_wei_block_double + wei_block_space);
- // odd iteration
- __syncthreads();
- // LDS double buffer: GEMM on current data
- blockwise_gemm.Run(p_wei_block_double + wei_block_space,
- p_in_block_double + in_block_space,
- p_out_thread);
+ constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
+ if(has_two_iteration_left) // if has 2 iteration left
+ {
+ // even iteration
+ Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+ Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
+ blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+ blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+ __syncthreads();
+ // LDS double buffer: load next data from device mem
+ blockwise_in_copy.RunLoadThreadBuffer(
+ p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
+ blockwise_wei_copy.RunLoadThreadBuffer(
+ p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
+ // LDS double buffer: GEMM on current data
+ blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+ // LDS double buffer: store next data to LDS
+ blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
+ p_in_block_double + in_block_space);
+ blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
+ p_wei_block_double + wei_block_space);
+ // odd iteration
+ __syncthreads();
+ // LDS double buffer: GEMM on current data
+ blockwise_gemm.Run(p_wei_block_double + wei_block_space,
+ p_in_block_double + in_block_space,
+ p_out_thread);
+ }
+ else // if has 1 iteration left
+ {
+ __syncthreads();
+ // LDS double buffer: GEMM on last data
+ blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+ }
+ }
  }
  // copy output: register to global memory
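The rewritten tail pairs with the relaxed divisibility assert earlier in this file (E % EPerBlock == 0 instead of E % (2 * EPerBlock) == 0): the main loop consumes E in pairs of EPerBlock chunks, so when it exits either two chunks or one chunk remain, and has_two_iteration_left selects the matching epilogue. A host-side sketch of that bookkeeping, with assumed example values of E and EPerBlock:

#include <cstdio>

int main()
{
    constexpr int EPerBlock = 8;
    for(int E : {32, 40}) // 4 chunks (even count) and 5 chunks (odd count)
    {
        int e = 0, main_iters = 0;
        for(; e + 2 * EPerBlock < E; e += 2 * EPerBlock) // same bound as the main body loop
            ++main_iters;
        const int chunks_left = (E - e) / EPerBlock; // 2 -> even+odd tail, 1 -> single GEMM tail
        std::printf("E=%d: main iterations=%d, tail chunks=%d\n", E, main_iters, chunks_left);
    }
}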
...
@@ -25,15 +25,15 @@ template <index_t GridSize,
  index_t GemmMPerBlock,
  index_t GemmNPerBlock,
  index_t GemmKPerBlock,
- index_t GemmMPerThreadSubC,
- index_t GemmNPerThreadSubC,
+ index_t GemmMPerThread,
+ index_t GemmNPerThread,
+ index_t GemmKPerThread,
  index_t GemmMLevel0Cluster,
  index_t GemmNLevel0Cluster,
  index_t GemmMLevel1Cluster,
  index_t GemmNLevel1Cluster,
- index_t GemmKPerThreadLoop,
- index_t GemmThreadGemmDataPerReadM,
- index_t GemmThreadGemmDataPerReadN,
+ index_t ThreadGemmAThreadCopySrcDataPerRead_GemmM,
+ index_t ThreadGemmAThreadCopySrcDataPerRead_GemmN,
  typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
  typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
  index_t GemmABlockCopySrcDataPerRead_GemmK,
@@ -130,19 +130,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
  decltype(wei_e_k_global_desc),
  decltype(in_e_b_global_desc),
  decltype(out_k_b_global_desc),
- InMemoryDataOperation::none,
+ InMemoryDataOperation::Set,
  GemmMPerBlock,
  GemmNPerBlock,
  GemmKPerBlock,
- GemmMPerThreadSubC,
- GemmNPerThreadSubC,
+ GemmMPerThread,
+ GemmNPerThread,
+ GemmKPerThread,
  GemmMLevel0Cluster,
  GemmNLevel0Cluster,
  GemmMLevel1Cluster,
  GemmNLevel1Cluster,
- GemmKPerThreadLoop,
- GemmThreadGemmDataPerReadM,
- GemmThreadGemmDataPerReadN,
+ ThreadGemmAThreadCopySrcDataPerRead_GemmM,
+ ThreadGemmAThreadCopySrcDataPerRead_GemmN,
  GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
  GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
  Sequence<1, 0>,
...
@@ -251,9 +251,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
  // LDS double buffer: preload data into LDS
  {
- blockwise_in_copy.template Run<Float, AddressSpace::global>(p_in_global,
+ blockwise_in_copy.template Run<Float, AddressSpace::Global>(p_in_global,
  p_in_block_double);
- blockwise_wei_copy.template Run<Float, AddressSpace::global>(p_wei_global,
+ blockwise_wei_copy.template Run<Float, AddressSpace::Global>(p_wei_global,
  p_wei_block_double);
  }
@@ -285,9 +285,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
  __syncthreads();
  // LDS doubel buffer: load next data from device mem
- blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
+ blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
  p_in_global, p_in_thread_buffer);
- blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
+ blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
  p_wei_global, p_wei_thread_buffer);
  // LDS double buffer: GEMM on current data
@@ -311,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
  __syncthreads();
  // LDS doubel buffer: load next data from device mem
- blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
+ blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
  p_in_global, p_in_thread_buffer);
- blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
+ blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
  p_wei_global, p_wei_thread_buffer);
  // LDS double buffer: GEMM on current data
@@ -390,7 +390,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
  for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
  {
  threadwise_out_copy
- .template Run<Float, AddressSpace::generic, AddressSpace::global>(p_out_thread,
+ .template Run<Float, AddressSpace::Generic, AddressSpace::Global>(p_out_thread,
  p_out_global);
  threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
...
@@ -60,7 +60,7 @@ __host__ __device__ constexpr auto
  template <typename... Ts>
  __host__ __device__ constexpr auto
  make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
  {
  using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
  static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
...
@@ -267,7 +267,7 @@ struct TensorCoordinate
  private:
  template <typename... Ts>
  __host__ __device__ static constexpr auto
  MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
  {
  return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
  make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
@@ -275,7 +275,7 @@ struct TensorCoordinate
  template <typename... Ts>
  __host__ __device__ static constexpr auto
  MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
  {
  return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
  make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
...
@@ -327,14 +327,14 @@ struct TensorCoordinate_deprecated
  private:
  template <class... Ts>
  __host__ __device__ static constexpr auto
  MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
  {
  return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
  }
  template <class... Ts>
  __host__ __device__ static constexpr auto
  MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
  {
  return MergedTensorCoordinate_deprecated<
  ConstantMergedTensorDescriptor_deprecated<Ts...>>();
...
...@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor, ...@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor,
index_t... LowerDimensionIds, index_t... LowerDimensionIds,
index_t... UpperDimensionIds> index_t... UpperDimensionIds>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>, Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>, Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>) Sequence<UpperDimensionIds...>)
{ {
return TransformedTensorDescriptor<LowerTensorDescriptor, return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>, Tuple<PassThrough<LowerLengths>...>,
...@@ -78,7 +78,7 @@ reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, ...@@ -78,7 +78,7 @@ reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
// reorder a NativeTensorDescriptor // reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper> template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper) reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{ {
static_assert(is_valid_sequence_map<MapLower2Upper>{}, static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map"); "wrong! MapLower2Upper is not a valid map");
...@@ -96,7 +96,7 @@ reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLo ...@@ -96,7 +96,7 @@ reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLo
// reorder a TransformedTensorDescriptor // reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper> template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper) reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{ {
static_assert(is_valid_sequence_map<MapLower2Upper>{}, static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map"); "wrong! MapLower2Upper is not a valid map");
...@@ -152,9 +152,9 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript ...@@ -152,9 +152,9 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{}; typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{}; constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};
// sanity-check if unfoldable // sanity-check if unfold-able
static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)), static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
"wrong! not unfoldable"); "wrong! not unfold-able");
// unfolded length, stride // unfolded length, stride
constexpr index_t unfold_length = constexpr index_t unfold_length =
......
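Editor's note: the hunk above tightens the unfold sanity check. unfold_tensor_descriptor merges a contiguous range of dimensions [FirstUnfoldDim, LastUnfoldDim] into one dimension; the sketch below is a minimal, standalone illustration of the contiguity condition such a check is assumed to enforce (the exact predicate inside are_dimensions_unfoldable is not shown in this diff, and dims_are_unfoldable here is a hypothetical stand-in).

// Editor's sketch (assumption): adjacent dimensions [first, last] can be unfolded into one
// only if they are packed, i.e. stride[i] == length[i+1] * stride[i+1] for every i in range.
#include <array>
#include <cstddef>

template <std::size_t NDim>
constexpr bool dims_are_unfoldable(const std::array<std::size_t, NDim>& lengths,
                                   const std::array<std::size_t, NDim>& strides,
                                   std::size_t first,
                                   std::size_t last)
{
    for(std::size_t i = first; i < last; ++i)
    {
        if(strides[i] != lengths[i + 1] * strides[i + 1])
            return false;
    }
    return true;
}

// e.g. a packed NCHW descriptor with lengths {N, C, H, W} and strides {C*H*W, H*W, W, 1}
// satisfies the condition for any range, so H and W (or C, H, W) can be unfolded into a
// single dimension of length H*W (or C*H*W) with stride 1.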
...@@ -23,8 +23,8 @@ template <index_t BlockSize, ...@@ -23,8 +23,8 @@ template <index_t BlockSize,
index_t MLevel1ThreadCluster, index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster, index_t NLevel1ThreadCluster,
index_t KPerThreadLoop, index_t KPerThreadLoop,
index_t DataPerReadA, index_t ThreadGemmADataPerRead_M,
index_t DataPerReadB> index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{ {
struct MatrixIndex struct MatrixIndex
...@@ -150,13 +150,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -150,13 +150,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
decltype(a_thread_mtx), decltype(a_thread_mtx),
KPerThreadLoop, KPerThreadLoop,
MPerThreadSubC, MPerThreadSubC,
DataPerReadA>{}; ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB, constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx), decltype(b_thread_mtx),
KPerThreadLoop, KPerThreadLoop,
NPerThreadSubC, NPerThreadSubC,
DataPerReadB>{}; ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx), ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
...@@ -238,13 +238,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -238,13 +238,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
decltype(a_thread_mtx), decltype(a_thread_mtx),
KPerThreadLoop, KPerThreadLoop,
MPerThreadSubC, MPerThreadSubC,
DataPerReadA>{}; ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB, constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx), decltype(b_thread_mtx),
KPerThreadLoop, KPerThreadLoop,
NPerThreadSubC, NPerThreadSubC,
DataPerReadB>{}; ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx), ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace ck { namespace ck {
// This threadwise copy allows vector access of src and dst. // This blockwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst. // It allows the vector size to be different on src and dst.
// The dimension of vector access can be different for src and dst. // The dimension of vector access can be different for src and dst.
// The dimension access order can be different for src and dst. // The dimension access order can be different for src and dst.
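Editor's note: the comment block above (reworded from "threadwise" to "blockwise" in this hunk) describes a copy whose source and destination may use different vector widths, vector dimensions, and access orders. Below is a minimal sketch of the mixed-vector-width idea, staged through a per-thread buffer; it is an illustration only, not the CK implementation, and copy_with_mixed_vector_widths is a hypothetical helper (real code would use actual vector types and descriptors).

// Editor's sketch: load SrcVec elements per access, store DstVec elements per access,
// staging through a thread-private buffer (which lives in registers, i.e. "vgpr").
#include <cstddef>

template <typename T, std::size_t SrcVec, std::size_t DstVec, std::size_t N>
__device__ void copy_with_mixed_vector_widths(const T* src, T* dst)
{
    static_assert(N % SrcVec == 0 && N % DstVec == 0,
                  "SrcVec and DstVec must both divide the slice length N");

    T buffer[N]; // per-thread staging buffer

    // load phase: grouped in SrcVec-wide accesses
    for(std::size_t i = 0; i < N; i += SrcVec)
        for(std::size_t v = 0; v < SrcVec; ++v)
            buffer[i + v] = src[i + v];

    // store phase: grouped in DstVec-wide accesses
    for(std::size_t i = 0; i < N; i += DstVec)
        for(std::size_t v = 0; v < DstVec; ++v)
            dst[i + v] = buffer[i + v];
}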
...@@ -28,10 +28,10 @@ template <index_t BlockSize, ...@@ -28,10 +28,10 @@ template <index_t BlockSize,
index_t DstVectorWriteDim, index_t DstVectorWriteDim,
index_t SrcDataPerRead, index_t SrcDataPerRead,
index_t DstDataPerWrite, index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic, AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::generic, AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::generic, AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none> InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
struct BlockwiseGenericTensorSliceCopy_v4 struct BlockwiseGenericTensorSliceCopy_v4
{ {
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension(); static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
...@@ -115,7 +115,7 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -115,7 +115,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
template <typename BlockSrcData, typename BlockDstData> template <typename BlockSrcData, typename BlockDstData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{ {
static_assert(ThreadBufferAddressSpace == AddressSpace::vgpr, static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
"wrong! This function use vgpr as its thread " "wrong! This function use vgpr as its thread "
"buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer " "buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer "
"to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. " "to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
...@@ -157,7 +157,7 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -157,7 +157,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
1, 1,
SrcAddressSpace, SrcAddressSpace,
ThreadBufferAddressSpace, ThreadBufferAddressSpace,
InMemoryDataOperation::none>; InMemoryDataOperation::Set>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc, using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
BlockDstDesc, BlockDstDesc,
......
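Editor's note: throughout this diff the enumerators of AddressSpace and InMemoryDataOperation are renamed to CamelCase (generic→Generic, global→Global, lds→Lds, vgpr→Vgpr, none→Set, atomic_add→AtomicAdd). A hedged sketch of the enumerations after the rename is given below; only the enumerators visible in the diff are taken from the source, the comments and exact definitions are assumptions.

enum class AddressSpace
{
    Generic, // no address-space qualifier
    Global,  // off-chip global memory
    Lds,     // on-chip local data share (shared memory)
    Vgpr     // per-thread registers
};

enum class InMemoryDataOperation
{
    Set,      // dst = value (plain store; previously spelled "none")
    AtomicAdd // dst += value atomically (previously "atomic_add"), e.g. col2im accumulation
};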
...@@ -499,7 +499,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated ...@@ -499,7 +499,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
ThreadBufferData* p_thread_buffer) const ThreadBufferData* p_thread_buffer) const
{ {
constexpr auto generic_address_space = constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{}; integral_constant<AddressSpace, AddressSpace::Generic>{};
RunLoadThreadBuffer( RunLoadThreadBuffer(
p_block_src, p_thread_buffer, generic_address_space, generic_address_space); p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
...@@ -529,7 +529,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated ...@@ -529,7 +529,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
BlockDstData* p_block_dst) const BlockDstData* p_block_dst) const
{ {
constexpr auto generic_address_space = constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{}; integral_constant<AddressSpace, AddressSpace::Generic>{};
RunStoreThreadBuffer( RunStoreThreadBuffer(
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space); p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
...@@ -548,7 +548,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated ...@@ -548,7 +548,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
BlockSrcData p_thread_buffer[GetThreadBufferSize()]; BlockSrcData p_thread_buffer[GetThreadBufferSize()];
constexpr auto generic_address_space = constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{}; integral_constant<AddressSpace, AddressSpace::Generic>{};
RunLoadThreadBuffer( RunLoadThreadBuffer(
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space); p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
...@@ -562,7 +562,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated ...@@ -562,7 +562,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{ {
constexpr auto generic_address_space = constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{}; integral_constant<AddressSpace, AddressSpace::Generic>{};
Run(p_block_src, p_block_dst, generic_address_space, generic_address_space); Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
} }
......
...@@ -22,15 +22,15 @@ template <index_t GridSize, ...@@ -22,15 +22,15 @@ template <index_t GridSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t MPerThreadSubC, index_t MPerThread,
index_t NPerThreadSubC, index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster, index_t MLevel0Cluster,
index_t NLevel0Cluster, index_t NLevel0Cluster,
index_t MLevel1Cluster, index_t MLevel1Cluster,
index_t NLevel1Cluster, index_t NLevel1Cluster,
index_t KPerThreadLoop, index_t ThreadGemmAThreadCopySrcDataPerRead_M,
index_t ThreadGemmDataPerReadM, index_t ThreadGemmBThreadCopySrcDataPerRead_N,
index_t ThreadGemmDataPerReadN,
typename ABlockCopyThreadSliceLengths_K_M, typename ABlockCopyThreadSliceLengths_K_M,
typename ABlockCopyThreadClusterLengths_K_M, typename ABlockCopyThreadClusterLengths_K_M,
typename ABlockCopyThreadClusterArrangeOrder, typename ABlockCopyThreadClusterArrangeOrder,
...@@ -54,8 +54,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -54,8 +54,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
{ {
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N, BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM, ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmDataPerReadN); ThreadGemmBThreadCopySrcDataPerRead_N);
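Editor's note: max_lds_align is the least common multiple of the per-access vector widths, so every tile placed in LDS starts at an offset compatible with the widest vectorized access. A worked example with hypothetical tuning values:

// with ABlockCopyDstDataPerWrite_M = 4, BBlockCopyDstDataPerWrite_N = 2,
//      ThreadGemmAThreadCopySrcDataPerRead_M = 4, ThreadGemmBThreadCopySrcDataPerRead_N = 4:
// max_lds_align = lcm(4, 2, 4, 4) = 4 elements, so the A and B LDS tiles are padded to a
// multiple of 4 elements and every 4-wide LDS read/write stays aligned.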
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -101,8 +101,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -101,8 +101,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// lds max alignment // lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N, BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM, ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmDataPerReadN); ThreadGemmBThreadCopySrcDataPerRead_N);
// divide block work by [M, N] // divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0, static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
...@@ -139,10 +139,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -139,10 +139,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
1, 1,
ABlockCopySrcDataPerRead, ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M, ABlockCopyDstDataPerWrite_M,
AddressSpace::global, AddressSpace::Global,
AddressSpace::vgpr, AddressSpace::Vgpr,
AddressSpace::lds, AddressSpace::Lds,
InMemoryDataOperation::none>( InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0}); {0, m_block_data_on_global}, {0, 0});
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
...@@ -165,10 +165,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -165,10 +165,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
1, 1,
BBlockCopySrcDataPerRead, BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N, BBlockCopyDstDataPerWrite_N,
AddressSpace::global, AddressSpace::Global,
AddressSpace::vgpr, AddressSpace::Vgpr,
AddressSpace::lds, AddressSpace::Lds,
InMemoryDataOperation::none>( InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0}); {0, n_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
...@@ -181,35 +181,33 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -181,35 +181,33 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc); constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check // sanity check
static_assert(MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 && static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0, NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!"); "wrong!");
constexpr index_t GemmMRepeat = constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO: a more elegant way of defining c_thread_mtx // TODO: a more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThreadSubC>{}, Number<GemmNRepeat * NPerThreadSubC>{}); Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
decltype(a_k_m_block_mtx_desc), decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc), decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc), decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThreadSubC, MPerThread,
NPerThreadSubC, NPerThread,
MLevel0Cluster, MLevel0Cluster,
NLevel0Cluster, NLevel0Cluster,
MLevel1Cluster, MLevel1Cluster,
NLevel1Cluster, NLevel1Cluster,
KPerThreadLoop, KPerThread,
ThreadGemmDataPerReadM, ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmDataPerReadN>{}; ThreadGemmBThreadCopySrcDataPerRead_N>{};
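Editor's note: the sanity check above requires MPerBlock (NPerBlock) to be a multiple of MPerThread * MLevel0Cluster * MLevel1Cluster (NPerThread * NLevel0Cluster * NLevel1Cluster). A worked example with hypothetical tuning values:

// with MPerBlock = NPerBlock = 128, MPerThread = NPerThread = 4,
//      MLevel0Cluster = NLevel0Cluster = 4, MLevel1Cluster = NLevel1Cluster = 4:
// one pass of the thread cluster covers 4 * 4 * 4 = 64 rows (and 64 columns), so
// GemmMRepeat = GemmNRepeat = 128 / 64 = 2, and each thread accumulates a
// (GemmMRepeat * MPerThread) x (GemmNRepeat * NPerThread) = 8 x 8 tile of C in registers,
// matching c_m0m1_n0n1_thread_mtx_desc above.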
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space = constexpr index_t a_block_space =
...@@ -233,6 +231,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -233,6 +231,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
b_blockwise_copy.Run(p_b_global, p_b_block_double); b_blockwise_copy.Run(p_b_global, p_b_block_double);
} }
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
// LDS double buffer: main body // LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock) k_block_data_begin += 2 * KPerBlock)
...@@ -255,8 +256,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -255,8 +256,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads(); __syncthreads();
...@@ -282,8 +283,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -282,8 +283,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads(); __syncthreads();
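Editor's note: the loop above implements LDS double buffering: while the blockwise GEMM consumes the tile already resident in one LDS buffer, the next K-slice is prefetched from global memory into registers and then stored into the other buffer, and MoveSrcSliceWindow advances the global read window by KPerBlock along K before each prefetch. A schematic of the pattern (comments only, not the actual kernel; the elided steps are paraphrased):

// for(k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
//     k_block_data_begin += 2 * KPerBlock)
// {
//     // even step (compute on the "even" LDS buffer, prefetch into the "odd" one):
//     a/b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}); // advance along K
//     __syncthreads();
//     // elided in this view: load the next K-slice global -> registers, run the
//     // blockwise GEMM on the resident LDS tile, store registers -> the other LDS buffer
//     // odd step: identical, with the roles of the two LDS buffers swapped
// }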
...@@ -317,16 +318,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -317,16 +318,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// input: register to global memory // input: register to global memory
{ {
constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1; constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1; constexpr index_t N0 = N / N1;
// define input tensor descriptor for threadwise copy // define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy // thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed( constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThreadSubC, GemmNRepeat, NPerThreadSubC>{}); Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor( constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc, c_m_n_global_desc,
...@@ -352,8 +353,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -352,8 +353,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
CThreadCopySrcDstVectorReadWriteDim, CThreadCopySrcDstVectorReadWriteDim,
1, 1,
CThreadCopyDstDataPerWrite, CThreadCopyDstDataPerWrite,
AddressSpace::vgpr, AddressSpace::Vgpr,
AddressSpace::global, AddressSpace::Global,
CGlobalMemoryDataOperation>( CGlobalMemoryDataOperation>(
{0, 0, 0, 0}, {0, 0, 0, 0},
{m_thread_data_on_global / M1, {m_thread_data_on_global / M1,
......
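Editor's note: for the register-to-global write of C, the global M index is split into (M0, M1) with M1 = MPerThread * MLevel0Cluster * MLevel1Cluster and M0 = M / M1 (and likewise for N). A worked example with hypothetical values; the start-coordinate entries after m_thread_data_on_global / M1 are elided in this view and are presumed to follow the same split.

// with M = 1024 and M1 = 4 * 4 * 4 = 64: M0 = 1024 / 64 = 16; a thread whose
// m_thread_data_on_global is 200 starts its C tile at m0 = 200 / 64 = 3, with the
// remaining coordinates following the analogous M1 / N1 splits, and
// transform_tensor_descriptor maps (m0, m1, n0, n1) back onto the original (M, N) layout.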
...@@ -21,9 +21,9 @@ template <typename SrcDesc, ...@@ -21,9 +21,9 @@ template <typename SrcDesc,
index_t SrcDstVectorReadWriteDim, index_t SrcDstVectorReadWriteDim,
index_t SrcDataPerRead, index_t SrcDataPerRead,
index_t DstDataPerWrite, index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic, AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::generic, AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none> InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
struct ThreadwiseGenericTensorSliceCopy_v4r2 struct ThreadwiseGenericTensorSliceCopy_v4r2
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -115,8 +115,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -115,8 +115,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData, transfer_data<SrcData,
SrcDataPerRead, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::Vgpr,
InMemoryDataOperation::none>( InMemoryDataOperation::Set>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
} }
} }
...@@ -146,7 +146,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -146,7 +146,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{ {
transfer_data<DstData, transfer_data<DstData,
DstDataPerWrite, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::Vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>( DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
...@@ -265,12 +265,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -265,12 +265,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData, transfer_data<SrcData,
SrcDataPerRead, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::Vgpr,
InMemoryDataOperation::none>(p_src, InMemoryDataOperation::Set>(p_src,
src_nonlinear_coord.GetOffset() + src_nonlinear_coord.GetOffset() +
src_linear_offset, src_linear_offset,
p_src_long_vector, p_src_long_vector,
buffer_offset); buffer_offset);
} }
} }
...@@ -303,7 +303,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -303,7 +303,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{ {
transfer_data<DstData, transfer_data<DstData,
DstDataPerWrite, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::Vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>( DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
...@@ -404,8 +404,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -404,8 +404,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData, transfer_data<SrcData,
SrcDataPerRead, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::vgpr, AddressSpace::Vgpr,
InMemoryDataOperation::none>( InMemoryDataOperation::Set>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
} }
} }
...@@ -448,7 +448,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -448,7 +448,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{ {
transfer_data<DstData, transfer_data<DstData,
DstDataPerWrite, DstDataPerWrite,
AddressSpace::vgpr, AddressSpace::Vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>(p_dst_long_vector, DstInMemOp>(p_dst_long_vector,
buffer_offset, buffer_offset,
......
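Editor's note: the last template argument of transfer_data selects how the value lands in the destination; the two variants used in this diff are Set (plain store) and AtomicAdd (atomic accumulation, as in the col2im kernel). A minimal sketch under that assumption, with a hypothetical write_element helper; the real transfer_data also handles address spaces and vectorized buffer intrinsics.

enum class InMemoryDataOperation { Set, AtomicAdd }; // as sketched earlier

template <typename T, InMemoryDataOperation Op>
__device__ void write_element(T* p_dst, T value)
{
    if constexpr(Op == InMemoryDataOperation::Set)
        *p_dst = value;          // dst = value, used by the GEMM C write-out
    else
        atomicAdd(p_dst, value); // dst += value, used where partial results must accumulate
}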
...@@ -333,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated ...@@ -333,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
// 2. src_normal_offset must be calculated at compile time (guaranteed by // 2. src_normal_offset must be calculated at compile time (guaranteed by
// algorithm) // algorithm)
// 3. src_merged_offset can be runtime value (no assumption imposed) // 3. src_merged_offset can be runtime value (no assumption imposed)
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto fwd) { static_if<SrcAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
vector_data = amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>( vector_data = amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>(
fwd(p_src), src_merged_offset, src_normal_offset); fwd(p_src), src_merged_offset, src_normal_offset);
...@@ -442,7 +442,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated ...@@ -442,7 +442,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
// 2. dst_normal_offset must be calculated at compile time (guaranteed by // 2. dst_normal_offset must be calculated at compile time (guaranteed by
// algorithm) // algorithm)
// 3. dst_merged_offset can be runtime value (no assumption imposed) // 3. dst_merged_offset can be runtime value (no assumption imposed)
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto fwd) { static_if<DstAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>( amd_intrinsic_buffer_store<DstData, DstDataPerAccess>(
vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset); vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset);
...@@ -464,7 +464,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated ...@@ -464,7 +464,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
__device__ void Run(const SrcData* p_src, DstData* p_dst) const __device__ void Run(const SrcData* p_src, DstData* p_dst) const
{ {
constexpr auto generic_address_space = constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{}; integral_constant<AddressSpace, AddressSpace::Generic>{};
Run(p_src, p_dst, generic_address_space, generic_address_space); Run(p_src, p_dst, generic_address_space, generic_address_space);
} }
......
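Editor's note: the deprecated copy above dispatches on the address space at compile time with static_if, selecting the AMD buffer-addressing intrinsics for global memory when CK_USE_AMD_BUFFER_ADDRESSING is enabled. Below is a hedged sketch of the same compile-time dispatch written with C++17 if constexpr (a modern equivalent, not the CK static_if helper; load_vector and its plain-load fallback are hypothetical).

enum class AddressSpace { Generic, Global, Lds, Vgpr }; // as sketched earlier

template <typename T, int N, AddressSpace Space>
__device__ void load_vector(const T* p_src, int offset, T* p_dst)
{
    if constexpr(Space == AddressSpace::Global)
    {
        // the real code may call amd_intrinsic_buffer_load<T, N> here when
        // CK_USE_AMD_BUFFER_ADDRESSING is enabled; plain loads are the portable fallback
        for(int i = 0; i < N; ++i)
            p_dst[i] = p_src[offset + i];
    }
    else
    {
        for(int i = 0; i < N; ++i)
            p_dst[i] = p_src[offset + i];
    }
}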