Unverified commit 1a66e35b authored by Chao Liu, committed by GitHub

MIOpen integration (#13)

* update for miopen integration: cosmetic refactor
parent 3406a114
......@@ -114,10 +114,10 @@ struct GridwiseCol2Im_eb_nchw
1,
BlockCopyDataPerAccess_B,
BlockCopyDataPerAccess_B,
AddressSpace::vgpr,
AddressSpace::vgpr,
AddressSpace::global,
InMemoryDataOperation::atomic_add>(
AddressSpace::Vgpr,
AddressSpace::Vgpr,
AddressSpace::Global,
InMemoryDataOperation::AtomicAdd>(
{e_block_data_on_global, b_block_data_on_global},
{e_block_data_on_global, b_block_data_on_global});
......
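The col2im kernel above writes its result with InMemoryDataOperation::AtomicAdd because neighbouring filter windows can contribute to the same input element. A minimal standalone C++17 sketch of that idea (MemOp and write are illustrative stand-ins, not the CK API):

#include <array>
#include <cstdio>

enum class MemOp { Set, AtomicAdd }; // stand-ins for InMemoryDataOperation

template <MemOp Op>
void write(float& dst, float v)
{
    if constexpr(Op == MemOp::Set) { dst = v; }  // overwrite: later windows clobber earlier ones
    else                           { dst += v; } // accumulate: overlapping windows add up
}

int main()
{
    // 1-D col2im with X = 3, stride = 1: adjacent output columns overlap in the input
    std::array<float, 5> in_grad{}; // Wi = 5
    constexpr int X = 3, Wo = 3;
    for(int wo = 0; wo < Wo; ++wo)
        for(int x = 0; x < X; ++x)
            write<MemOp::AtomicAdd>(in_grad[wo + x], 1.0f); // Set would silently drop contributions

    for(float v : in_grad) std::printf("%.0f ", v); // prints: 1 2 3 2 1
}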
......@@ -25,15 +25,15 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmM,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmN,
......@@ -75,25 +75,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
//\todo static_assert for global vector load/store
// static_assert();
// weight tensor
constexpr auto wei_gemmk_gemmm_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
// output tensor
constexpr auto out_k_b_global_desc =
constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// weight tensor
constexpr auto wei_k_e_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
......@@ -116,38 +111,42 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
// \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
// atomic, find out all of them
constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and
(ConvStrideW >= ConvDilationW * (X - 1) + 1);
constexpr auto in_memory_op =
not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd;
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc),
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
in_memory_op,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
ThreadGemmAThreadCopySrcDataPerRead_GemmM,
ThreadGemmAThreadCopySrcDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
......
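A compile-time restatement of the not_need_atomic test above (sketch only, plain int instead of index_t). Each output position reads input offsets ConvStride apart and one filter window spans ConvDilation*(Y-1)+1 input elements, so the windows cannot collide when the stride covers that footprint; as the TODO notes, this test is sufficient but not tight.

constexpr bool need_atomic(int stride, int dilation, int filter)
{
    // conservative test mirroring not_need_atomic: windows of span dilation*(filter-1)+1
    // start stride apart, so they cannot collide when the stride is at least that span
    return stride < dilation * (filter - 1) + 1;
}

static_assert(!need_atomic(/*stride*/ 3, /*dilation*/ 1, /*Y*/ 3), "footprint 3 fits in stride 3: Set");
static_assert( need_atomic(/*stride*/ 2, /*dilation*/ 1, /*Y*/ 3), "footprint 3 > stride 2: AtomicAdd");
static_assert(!need_atomic(/*stride*/ 1, /*dilation*/ 1, /*Y*/ 1), "1x1 filter never overlaps");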
......@@ -147,10 +147,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
2,
OutBlockCopySrcDataPerRead_B,
OutBlockCopyDstDataPerWrite_N0,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, b_block_data_on_global, 0}, {0, 0, 0});
// weight tensor
......@@ -187,10 +187,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
2,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_C0,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, e_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
......@@ -356,10 +356,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
#if 1 // debug
// input: register to global memory, atomic add
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
? InMemoryDataOperation::Set
: InMemoryDataOperation::AtomicAdd;
#else
constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd;
#endif
constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
......@@ -432,8 +432,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
4,
1,
InThreadCopyDstDataPerWrite_B,
AddressSpace::vgpr,
AddressSpace::global,
AddressSpace::Vgpr,
AddressSpace::Global,
in_memory_op>({0, 0, 0, 0, 0, 0},
{e_thread_data_on_global / E1,
e_thread_data_on_global % E1,
......
......@@ -8,9 +8,9 @@
namespace ck {
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * Ydot * Xdot;
// GemmM = C * YTilda * XTilda;
// GemmN = N * HTildaSlice * WTildaSlice;
// GemmK = K * YDot * XDot;
template <index_t GridSize,
index_t BlockSize,
typename Float,
......@@ -25,13 +25,13 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
......@@ -81,32 +81,32 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
"be violated");
#endif
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t Htilda =
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
......@@ -114,17 +114,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<C, Ytilda, Xtilda>>{}),
make_tuple(Merge<Sequence<K, YDot, XDot>>{}, Merge<Sequence<C, YTilda, XTilda>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -134,33 +134,33 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
transform_tensor_descriptor(out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDot, XDot>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -188,35 +188,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
transform_tensor_descriptor(in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
make_tuple(Merge<Sequence<C, YTilda, XTilda>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -229,17 +229,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
......
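Worked numbers for the YTilda/YDot/HTilda/HTildaSlice quantities computed above (standalone sketch with illustrative sizes; the helper functions mirror the formulas rather than calling the ck::math routines):

constexpr int gcd_(int a, int b) { return b == 0 ? a : gcd_(b, a % b); }
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
constexpr int max_(int a, int b) { return a > b ? a : b; }
constexpr int min_(int a, int b) { return a < b ? a : b; }

// H-direction example: Hi = 15, Ho = 8, Y = 3, ConvStrideH = 2, ConvDilationH = 1, LeftPadH = 1
constexpr int Hi = 15, Ho = 8, Y = 3, StrideH = 2, DilationH = 1, PadH = 1;

constexpr int YTilda = StrideH / gcd_(StrideH, DilationH);                       // 2
constexpr int YDot   = ceil_div(Y, YTilda);                                      // 2
constexpr int HTilda = Ho + ceil_div(DilationH * (Y - 1), StrideH);              // 8 + 1 = 9
constexpr int HTildaLeft  = max_(0, PadH - DilationH * (YTilda - 1)) / StrideH;  // 0
constexpr int HTildaRight = min_(HTilda, ceil_div(PadH + Hi - 1, StrideH) + 1);  // 9
constexpr int HTildaSlice = HTildaRight - HTildaLeft;                            // 9

static_assert(YTilda == 2 && YDot == 2 && HTilda == 9 && HTildaSlice == 9, "");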
......@@ -8,10 +8,10 @@
namespace ck {
// Ytilda*Xtilda number of GEMMs
// GemmM = C;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * YdotNonZero * XdotNonZero;
// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
template <index_t GridSize,
index_t BlockSize,
typename Float,
......@@ -26,13 +26,13 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
......@@ -110,32 +110,32 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
"be violated");
#endif
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t Htilda =
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
constexpr bool wei_skip_all_out_of_bound_check = true;
......@@ -145,12 +145,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -167,26 +167,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -216,26 +216,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
......@@ -246,54 +246,49 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
__shared__ Float p_shared_block[shared_block_size];
#if 1 // debug
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
#else
static_for<0, 1, 1>{}([&](auto ytilda_) {
static_for<0, 1, 1>{}([&](auto xtilda_) {
#endif
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};
static_for<0, YTilda, 1>{}([&](auto iYTilda_) {
static_for<0, XTilda, 1>{}([&](auto iXTilda_) {
constexpr index_t iYTilda = decltype(iYTilda_){};
constexpr index_t iXTilda = decltype(iXTilda_){};
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
// A matrix
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
constexpr auto wei_k_c_ydotslice_ytildaslice_xdotslice_xtildaslice_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<Ydot, Xdot>,
Slice<Sequence<YDot, XDot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
wei_k_c_ydotslice_ytildaslice_xdotslice_xtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ydot, Xdot>,
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YDot, XDot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{}),
Sequence<YDotSlice, XDotSlice>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......@@ -306,23 +301,23 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
......@@ -335,9 +330,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Sequence<2, 4>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -349,17 +344,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
......
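The per-iteration YDotSlice/XDotSlice above shrink the GemmK extent of the last GEMM when Y (or X) is not a multiple of YDot (or XDot). A small compile-time check of that selection (sketch only):

constexpr int ydot_slice(int iYTilda, int YDot, int Y)
{
    return (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
}
// Y = 3, YTilda = 2, YDot = 2: the two GEMMs along H see 2 and 1 filter taps respectively
static_assert(ydot_slice(0, 2, 3) == 2, "");
static_assert(ydot_slice(1, 2, 3) == 1, "");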
......@@ -229,10 +229,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
......@@ -269,10 +269,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
......@@ -344,6 +344,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
constexpr auto in_block_slice_copy_steps = Sequence<EPerBlock, 0, 0, 0>{};
constexpr auto wei_block_slice_copy_steps = Sequence<EPerBlock, 0>{};
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
......@@ -366,8 +369,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);
__syncthreads();
......@@ -393,8 +396,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);
__syncthreads();
......@@ -482,14 +485,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
3,
1,
1,
AddressSpace::vgpr,
AddressSpace::global,
InMemoryDataOperation::none>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
AddressSpace::Vgpr,
AddressSpace::Global,
InMemoryDataOperation::Set>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.Run(p_out_thread, p_out_global);
}
}
......
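The main body above follows the usual LDS double-buffering shape: advance the source window, synchronize, prefetch the next EPerBlock chunk into registers, run the GEMM on the current LDS buffer, then commit the prefetch into the other buffer. A host-side sketch of that control flow (FakeCopy, FakeGemm and block_sync are placeholders for the blockwise copies, blockwise_gemm and __syncthreads(), not the real interfaces):

struct FakeCopy { void move_window(int) {} void load() {} void store(int /*buffer*/) {} };
struct FakeGemm { void run(int /*buffer*/) {} };
inline void block_sync() {} // stands in for __syncthreads()

void double_buffer_main_body(FakeCopy& copy, FakeGemm& gemm, int E, int EPerBlock)
{
    // buffer 0 was preloaded before the loop; every iteration retires two chunks
    for(int e = 0; e + 2 * EPerBlock < E; e += 2 * EPerBlock)
    {
        for(int parity = 0; parity < 2; ++parity)
        {
            copy.move_window(EPerBlock); // MoveSrcSliceWindow
            block_sync();                // previous GEMM must be done reading this LDS buffer
            copy.load();                 // global -> registers (next chunk)
            gemm.run(parity);            // GEMM on the current LDS buffer
            copy.store(1 - parity);      // registers -> the other LDS buffer
        }
    }
}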
......@@ -94,9 +94,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
integral_constant<AddressSpace, AddressSpace::Global>{};
static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
......@@ -141,13 +141,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
static_assert(
(Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
......@@ -357,37 +358,49 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
// LDS double buffer: tail
{
// even iteration
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
if(has_two_iteration_left) // if has 2 iteration left
{
// even iteration
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
__syncthreads();
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}
// copy output: register to global memory
......
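Relaxing the divisibility requirement to E % EPerBlock == 0 works because the main body retires chunks in pairs, so the tail sees either one or two EPerBlock chunks, and the two branches above cover both cases. A small compile-time check of how many chunks reach the tail (sketch, C++14 constexpr):

constexpr int chunks_left_after_main_body(int E, int EPerBlock)
{
    int e = 0;
    for(; e + 2 * EPerBlock < E; e += 2 * EPerBlock) {}
    return (E - e) / EPerBlock;
}
static_assert(chunks_left_after_main_body(4 * 8, 8) == 2, "even multiple -> 2-chunk tail");
static_assert(chunks_left_after_main_body(3 * 8, 8) == 1, "odd multiple  -> 1-chunk tail");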
......@@ -25,15 +25,15 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmM,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
......@@ -130,19 +130,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
decltype(wei_e_k_global_desc),
decltype(in_e_b_global_desc),
decltype(out_k_b_global_desc),
InMemoryDataOperation::none,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
ThreadGemmAThreadCopySrcDataPerRead_GemmM,
ThreadGemmAThreadCopySrcDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
......
......@@ -251,9 +251,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, AddressSpace::global>(p_in_global,
blockwise_in_copy.template Run<Float, AddressSpace::Global>(p_in_global,
p_in_block_double);
blockwise_wei_copy.template Run<Float, AddressSpace::global>(p_wei_global,
blockwise_wei_copy.template Run<Float, AddressSpace::Global>(p_wei_global,
p_wei_block_double);
}
......@@ -285,9 +285,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
......@@ -311,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
......@@ -390,7 +390,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy
.template Run<Float, AddressSpace::generic, AddressSpace::global>(p_out_thread,
.template Run<Float, AddressSpace::Generic, AddressSpace::Global>(p_out_thread,
p_out_global);
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
......
......@@ -60,7 +60,7 @@ __host__ __device__ constexpr auto
template <typename... Ts>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
{
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
......
......@@ -267,7 +267,7 @@ struct TensorCoordinate
private:
template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
{
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
......@@ -275,7 +275,7 @@ struct TensorCoordinate
template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
{
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
......
......@@ -327,14 +327,14 @@ struct TensorCoordinate_deprecated
private:
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
{
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
}
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
{
return MergedTensorCoordinate_deprecated<
ConstantMergedTensorDescriptor_deprecated<Ts...>>();
......
......@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor,
index_t... LowerDimensionIds,
index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
{
return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>,
......@@ -78,7 +78,7 @@ reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
......@@ -96,7 +96,7 @@ reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLo
// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
......@@ -152,9 +152,9 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};
// sanity-check if unfoldable
// sanity-check if unfold-able
static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
"wrong! not unfoldable");
"wrong! not unfold-able");
// unfolded length, stride
constexpr index_t unfold_length =
......
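The unfold is only legal when the folded dimensions are contiguous in memory, which is what are_dimensions_unfoldable guards. A sketch of that condition with concrete numbers (the unfoldable helper below is illustrative, not the CK function; C++14):

#include <array>

constexpr bool unfoldable(const std::array<int, 4>& len, const std::array<int, 4>& str,
                          int first, int last)
{
    // dims [first, last] can be fused only if each stride equals the next
    // dimension's stride times its length, i.e. the region has no gaps
    for(int d = first; d < last; ++d)
        if(str[d] != str[d + 1] * len[d + 1]) return false;
    return true;
}
// packed K=8, C=4, Y=3, X=3 weight: strides {36, 9, 3, 1}; fusing C,Y,X (dims 1..3)
// gives a K x (C*Y*X) = 8 x 36 view with strides {36, 1}
static_assert(unfoldable({8, 4, 3, 3}, {36, 9, 3, 1}, 1, 3), "");
static_assert(!unfoldable({8, 4, 3, 3}, {72, 9, 3, 1}, 0, 1), "gap between K and C, cannot fuse");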
......@@ -23,8 +23,8 @@ template <index_t BlockSize,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t KPerThreadLoop,
index_t DataPerReadA,
index_t DataPerReadB>
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
......@@ -150,13 +150,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
DataPerReadA>{};
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
DataPerReadB>{};
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
......@@ -238,13 +238,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
DataPerReadA>{};
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
DataPerReadB>{};
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
......
......@@ -9,7 +9,7 @@
namespace ck {
// This threadwise copy allow vector access of src and dst.
// This blockwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access can be different for src and dst.
// The dimension access order can be different for src and dst.
......@@ -28,10 +28,10 @@ template <index_t BlockSize,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
struct BlockwiseGenericTensorSliceCopy_v4
{
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
......@@ -115,7 +115,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
template <typename BlockSrcData, typename BlockDstData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
static_assert(ThreadBufferAddressSpace == AddressSpace::vgpr,
static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
"wrong! This function use vgpr as its thread "
"buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer "
"to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
......@@ -157,7 +157,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
1,
SrcAddressSpace,
ThreadBufferAddressSpace,
InMemoryDataOperation::none>;
InMemoryDataOperation::Set>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
BlockDstDesc,
......
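The Run overload asserted on above is the two-stage pattern made explicit: a threadwise load into a register (Vgpr) thread buffer, followed by a threadwise store that applies DstInMemOp. A scalar sketch of that shape (run_sketch and its plain loops are illustrative; the real copy moves whole slices with vectorized access):

template <typename SrcData, typename DstData, int N>
void run_sketch(const SrcData* p_src, DstData* p_dst)
{
    SrcData thread_buffer[N]; // lives in registers, i.e. AddressSpace::Vgpr
    for(int i = 0; i < N; ++i)
        thread_buffer[i] = p_src[i]; // RunLoadThreadBuffer: src -> thread buffer, always Set
    for(int i = 0; i < N; ++i)
        p_dst[i] = static_cast<DstData>(thread_buffer[i]); // RunStoreThreadBuffer: applies DstInMemOp (Set here)
}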
......@@ -499,7 +499,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
ThreadBufferData* p_thread_buffer) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};
RunLoadThreadBuffer(
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
......@@ -529,7 +529,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};
RunStoreThreadBuffer(
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
......@@ -548,7 +548,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};
RunLoadThreadBuffer(
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
......@@ -562,7 +562,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};
Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
}
......
......@@ -22,15 +22,15 @@ template <index_t GridSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t ThreadGemmDataPerReadM,
index_t ThreadGemmDataPerReadN,
index_t ThreadGemmAThreadCopySrcDataPerRead_M,
index_t ThreadGemmBThreadCopySrcDataPerRead_N,
typename ABlockCopyThreadSliceLengths_K_M,
typename ABlockCopyThreadClusterLengths_K_M,
typename ABlockCopyThreadClusterArrangeOrder,
......@@ -54,8 +54,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
{
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -101,8 +101,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
......@@ -139,10 +139,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0});
// B matrix in LDS memory, dst of blockwise copy
......@@ -165,10 +165,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0});
// GEMM definition
......@@ -181,35 +181,33 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check
static_assert(MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat =
NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThreadSubC>{}, Number<GemmNRepeat * NPerThreadSubC>{});
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThreadSubC,
NPerThreadSubC,
MPerThread,
NPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThreadLoop,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN>{};
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
......@@ -233,6 +231,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
......@@ -255,8 +256,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads();
......@@ -282,8 +283,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads();
......@@ -317,16 +318,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// input: register to global memory
{
constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThreadSubC, GemmNRepeat, NPerThreadSubC>{});
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
......@@ -352,8 +353,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
CThreadCopySrcDstVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::global,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
......
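Worked numbers for the M0/M1 split used by the output copy above (illustrative values, not a tuned kernel configuration): M1 is the stretch of M covered by one pass of the thread clusters, M0 counts those passes across the whole matrix, and GemmMRepeat counts them inside one block.

constexpr int MPerThread = 4, MLevel0Cluster = 4, MLevel1Cluster = 2;
constexpr int MPerBlock = 64, M = 256;

constexpr int M1 = MPerThread * MLevel0Cluster * MLevel1Cluster; // 32: M covered per cluster pass
constexpr int M0 = M / M1;                                       // 8:  number of such passes in M
constexpr int GemmMRepeat = MPerBlock / M1;                      // 2:  repeats per block along M

static_assert(M1 == 32 && M0 == 8 && GemmMRepeat == 2, "");
static_assert(MPerBlock % M1 == 0, "same divisibility the kernel asserts");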
......@@ -21,9 +21,9 @@ template <typename SrcDesc,
index_t SrcDstVectorReadWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
struct ThreadwiseGenericTensorSliceCopy_v4r2
{
static constexpr index_t nDim = SliceLengths::Size();
......@@ -115,8 +115,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
AddressSpace::Vgpr,
InMemoryDataOperation::Set>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
}
}
......@@ -146,7 +146,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
......@@ -265,12 +265,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(p_src,
src_nonlinear_coord.GetOffset() +
src_linear_offset,
p_src_long_vector,
buffer_offset);
AddressSpace::Vgpr,
InMemoryDataOperation::Set>(p_src,
src_nonlinear_coord.GetOffset() +
src_linear_offset,
p_src_long_vector,
buffer_offset);
}
}
......@@ -303,7 +303,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
......@@ -404,8 +404,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
AddressSpace::Vgpr,
InMemoryDataOperation::Set>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
}
}
......@@ -448,7 +448,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp>(p_dst_long_vector,
buffer_offset,
......
......@@ -333,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
// 2. src_normal_offset must be calculated at compile time (guaranteed by
// algorithm)
// 3. src_merged_offset can be runtime value (no assumption imposed)
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto fwd) {
static_if<SrcAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING
vector_data = amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>(
fwd(p_src), src_merged_offset, src_normal_offset);
......@@ -442,7 +442,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
// 2. dst_normal_offset must be calculated at compile time (guaranteed by
// algorithm)
// 3. dst_merged_offset can be runtime value (no assumption imposed)
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto fwd) {
static_if<DstAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>(
vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset);
......@@ -464,7 +464,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};
Run(p_src, p_dst, generic_address_space, generic_address_space);
}
......