"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "4d84f886e7acd6504233002076d711acd5c7eb9c"
Unverified Commit 3406a114 authored by Chao Liu, committed by GitHub

Update for recent MIOpen integration (#11)

* update for MIOpen integration
parent c5da0377
@@ -49,7 +49,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
                         const Float* __restrict__ p_wei_global,
                         const Float* __restrict__ p_out_global) const
     {
-        constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
@@ -85,11 +84,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
                       "be violated");

         // output tensor
-        constexpr auto out_n_k_howo_global_desc =
-            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
-
         constexpr auto out_k_b_global_desc =
-            transform_tensor_descriptor(out_n_k_howo_global_desc,
+            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
                                         make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
                                         make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
...
@@ -353,7 +353,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
         }

         {
 #if 1 // debug
             // input: register to global memory, atomic add
             constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
                                               ? InMemoryDataOperation::none
...
@@ -81,11 +81,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                       "be violated");
 #endif

-        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

         constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
         constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
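For intuition, here is a standalone sketch of the tilde-dimension arithmetic above (plain `int` in place of `index_t`; the filter/stride values are assumed for illustration, not taken from this commit):

```cpp
#include <cassert>

// Euclidean form; CK's math::gcd (renamed from hcf later in this diff) uses subtraction.
constexpr int gcd(int x, int y) { return y == 0 ? x : gcd(y, x % y); }

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // e.g. a filter with Y = 3 taps in H, stride 2, dilation 1
    constexpr int ConvStrideH = 2, ConvDilationH = 1, Y = 3;

    constexpr int Ytilda = ConvStrideH / gcd(ConvStrideH, ConvDilationH); // 2 / 1 = 2
    constexpr int Ydot   = integer_divide_ceil(Y, Ytilda);                // ceil(3 / 2) = 2

    static_assert(Ytilda == 2 && Ydot == 2, "Y splits into Ytilda slices of up to Ydot taps");
    return 0;
}
```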
@@ -115,10 +115,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                 PassThrough<C>{},
                 Embed<Y,
                       Sequence<Ydot, Ytilda>,
-                      Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
+                      Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
                 Embed<X,
                       Sequence<Xdot, Xtilda>,
-                      Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
+                      Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -135,10 +135,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                 PassThrough<K>{},
                 Embed<Ho,
                       Sequence<Ydot, Htilda>,
-                      Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
+                      Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
                 Embed<Wo,
                       Sequence<Xdot, Wtilda>,
-                      Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
+                      Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
...
@@ -110,11 +110,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                       "be violated");
 #endif

-        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

         constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
         constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
@@ -146,11 +146,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                 PassThrough<C>{},
                 Embed<Y,
                       Sequence<Ydot, Ytilda>,
-                      Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
                       wei_skip_all_out_of_bound_check>{},
                 Embed<X,
                       Sequence<Xdot, Xtilda>,
-                      Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
                       wei_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -168,11 +168,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                 PassThrough<K>{},
                 Embed<Ho,
                       Sequence<Ydot, Htilda>,
-                      Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
                       out_skip_all_out_of_bound_check>{},
                 Embed<Wo,
                       Sequence<Xdot, Wtilda>,
-                      Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
                       out_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
...
@@ -22,8 +22,6 @@ template <index_t GridSize,
           typename ConvDilations,
           typename InLeftPads,
           typename InRightPads,
-          index_t Iter_ytilda,
-          index_t Iter_xtilda,
           index_t GemmMPerBlock,
           index_t GemmNPerBlock,
           index_t GemmKPerBlock,
@@ -47,9 +45,27 @@ template <index_t GridSize,
           index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
 struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
 {
-    __device__ void Run(Float* __restrict__ p_in_global,
-                        const Float* __restrict__ p_wei_global,
-                        const Float* __restrict__ p_out_global) const
+    __host__ __device__ static constexpr index_t GetNumberOfGemm()
+    {
+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
+
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
+
+        return Ytilda * Xtilda;
+    }
+
+    template <index_t iYTilda, index_t iXTilda>
+    __device__ static void RunImpl(Float* __restrict__ p_in_global,
+                                   const Float* __restrict__ p_wei_global,
+                                   const Float* __restrict__ p_out_global)
     {
         constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
         constexpr auto wei_k_c_y_x_global_desc  = WeiGlobalDesc{};
@@ -83,11 +99,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                       "be violated");
 #endif

-        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

         constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
         constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
@@ -119,11 +135,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                 PassThrough<C>{},
                 Embed<Y,
                       Sequence<Ydot, Ytilda>,
-                      Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
                       wei_skip_all_out_of_bound_check>{},
                 Embed<X,
                       Sequence<Xdot, Xtilda>,
-                      Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
                       wei_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -141,11 +157,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                 PassThrough<K>{},
                 Embed<Ho,
                       Sequence<Ydot, Htilda>,
-                      Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
                       out_skip_all_out_of_bound_check>{},
                 Embed<Wo,
                       Sequence<Xdot, Wtilda>,
-                      Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
                       out_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -215,8 +231,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                 Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

         // GEMM
-        constexpr index_t ytilda = Iter_ytilda;
-        constexpr index_t xtilda = Iter_xtilda;
+        constexpr index_t ytilda = iYTilda;
+        constexpr index_t xtilda = iXTilda;

         constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
         constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
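A worked example (sizes assumed) of the `YdotNonZero` edge handling above: the last tilde slice is ragged whenever `Ytilda` does not divide `Y`.

```cpp
#include <cassert>

int main()
{
    constexpr int Y = 3, Ytilda = 2;
    constexpr int Ydot = (Y + Ytilda - 1) / Ytilda; // integer_divide_ceil(3, 2) = 2

    // ytilda = 0: (0 + 1) * 2 <= 3 -> full slice, YdotNonZero = Ydot     = 2
    // ytilda = 1: (1 + 1) * 2 >  3 -> tail slice, YdotNonZero = Y % Ydot = 1
    for(int ytilda = 0; ytilda < Ytilda; ++ytilda)
    {
        const int YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
        assert(YdotNonZero == (ytilda == 0 ? 2 : 1));
    }
    return 0;
}
```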
@@ -327,6 +343,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
         gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
     }
+
+    template <index_t GemmId>
+    __device__ static void Run(Float* __restrict__ p_in_global,
+                               const Float* __restrict__ p_wei_global,
+                               const Float* __restrict__ p_out_global)
+    {
+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
+
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
+
+        constexpr index_t iYTilda = GemmId / Xtilda;
+        constexpr index_t iXTilda = GemmId % Xtilda;
+
+        static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda");
+
+        RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
+    }
 };

 } // namespace ck
...
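Taken together, `GetNumberOfGemm()` and the new `Run<GemmId>` entry point replace the old `Iter_ytilda`/`Iter_xtilda` template parameters with one flat id. A standalone sketch of the decomposition (stride/dilation values assumed for illustration):

```cpp
#include <cassert>

constexpr int gcd(int x, int y) { return y == 0 ? x : gcd(y, x % y); }

int main()
{
    constexpr int ConvStrideH = 2, ConvStrideW = 3;
    constexpr int ConvDilationH = 1, ConvDilationW = 1;

    constexpr int Ytilda = ConvStrideH / gcd(ConvStrideH, ConvDilationH); // 2
    constexpr int Xtilda = ConvStrideW / gcd(ConvStrideW, ConvDilationW); // 3

    // GetNumberOfGemm() would report Ytilda * Xtilda = 6 sub-GEMMs;
    // Run<GemmId> walks them in row-major (iYTilda, iXTilda) order.
    for(int gemm_id = 0; gemm_id < Ytilda * Xtilda; ++gemm_id)
    {
        const int iYTilda = gemm_id / Xtilda;
        const int iXTilda = gemm_id % Xtilda;
        assert(iYTilda < Ytilda && iXTilda < Xtilda);
    }
    return 0;
}
```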
@@ -49,7 +49,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
     {
-        constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};
@@ -117,9 +116,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // output tensor
         constexpr auto out_k_b_global_desc =
-            transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
-                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
-                                        make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
+            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
+                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
+                                        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));

         // GEMM
...
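The descriptor change above (and its v1r1 counterpart earlier in this diff) relies on an index equivalence: unfolding `(Ho, Wo)` into a single `HoWo` axis and then merging `(N, HoWo)` yields the same GEMM `B` index as merging `(N, Ho, Wo)` directly, provided `Ho` and `Wo` are contiguous as `unfold_tensor_descriptor` requires. A plain arithmetic check (sizes assumed):

```cpp
#include <cassert>

int main()
{
    constexpr int N = 2, Ho = 3, Wo = 4; // illustrative sizes

    for(int n = 0; n < N; ++n)
        for(int ho = 0; ho < Ho; ++ho)
            for(int wo = 0; wo < Wo; ++wo)
            {
                const int howo = ho * Wo + wo;            // unfold (Ho, Wo) -> HoWo
                const int b2   = n * (Ho * Wo) + howo;    // Merge<Sequence<N, Ho * Wo>>
                const int b3   = (n * Ho + ho) * Wo + wo; // Merge<Sequence<N, Ho, Wo>>
                assert(b2 == b3);
            }
    return 0;
}
```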
@@ -47,6 +47,9 @@ struct PassThrough
     }
 };

+// By default, whether the is-valid check for the upper-to-lower index mapping is
+// necessary is judged automatically; however, the check is skipped if
+// SkipIsValidCheck is set to true by the user
 // LowerLengths: Sequence<...>
 template <typename LowerLengths,
           typename LeftPads,
@@ -92,12 +95,12 @@ struct Pad
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
-#if 1 // debug
+        // skip the valid check if the user requests it
         if(SkipIsValidCheck)
         {
             return true;
         }
-#endif

         bool flag = true;

         for(index_t i = 0; i < nDim; ++i)
@@ -384,6 +387,9 @@ struct UnMerge
     }
 };

+// By default, whether the is-valid check for the upper-to-lower index mapping is
+// necessary is judged automatically; however, the check is skipped if
+// SkipIsValidCheck is set to true by the user
 // UpperLengths: Sequence<...>
 // Coefficients: Sequence<...>
 // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
@@ -442,12 +448,12 @@ struct Embed
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
-#if 1 // debug
+        // skip the valid check if the user requests it
         if(SkipIsValidCheck)
         {
             return true;
         }
-#endif

         bool flag = true;

         index_t ncorner = 1;
...
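A minimal standalone sketch of the pattern `Pad` and `Embed` now share; the function name and the all-pads-zero condition below are assumptions for illustration, simplified from `Pad`:

```cpp
template <bool SkipIsValidCheck, int nDim>
constexpr bool is_valid_upper_always_valid_lower(const int (&left_pads)[nDim],
                                                 const int (&right_pads)[nDim])
{
    // skip the valid check if the user requests it (compile-time short-circuit)
    if(SkipIsValidCheck)
    {
        return true;
    }

    // otherwise judge automatically; e.g. a Pad is always-valid only with zero padding
    bool flag = true;
    for(int i = 0; i < nDim; ++i)
    {
        flag = flag && left_pads[i] == 0 && right_pads[i] == 0;
    }
    return flag;
}

int main()
{
    constexpr int lp[2] = {0, 0}, rp[2] = {0, 1};
    static_assert(is_valid_upper_always_valid_lower<false, 2>(lp, rp) == false, "");
    static_assert(is_valid_upper_always_valid_lower<true, 2>(lp, rp) == true, "");
    return 0;
}
```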
@@ -112,11 +112,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<SrcData,
-                              SrcDataPerRead,
-                              SrcAddressSpace,
-                              AddressSpace::vgpr,
-                              InMemoryDataOperation::none>(
+                    transfer_data<SrcData,
+                                  SrcDataPerRead,
+                                  SrcAddressSpace,
+                                  AddressSpace::vgpr,
+                                  InMemoryDataOperation::none>(
                         p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
                 }
             }
@@ -144,11 +144,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<DstData,
-                              DstDataPerWrite,
-                              AddressSpace::vgpr,
-                              DstAddressSpace,
-                              DstInMemOp>(
+                    transfer_data<DstData,
+                                  DstDataPerWrite,
+                                  AddressSpace::vgpr,
+                                  DstAddressSpace,
+                                  DstInMemOp>(
                         p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
                 }
             }
@@ -262,15 +262,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     // has the valid/invalid mapping situation
                     if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                     {
-                        move_data<SrcData,
-                                  SrcDataPerRead,
-                                  SrcAddressSpace,
-                                  AddressSpace::vgpr,
-                                  InMemoryDataOperation::none>(p_src,
-                                                               src_nonlinear_coord.GetOffset() +
-                                                                   src_linear_offset,
-                                                               p_src_long_vector,
-                                                               buffer_offset);
+                        transfer_data<SrcData,
+                                      SrcDataPerRead,
+                                      SrcAddressSpace,
+                                      AddressSpace::vgpr,
+                                      InMemoryDataOperation::none>(p_src,
+                                                                   src_nonlinear_coord.GetOffset() +
+                                                                       src_linear_offset,
+                                                                   p_src_long_vector,
+                                                                   buffer_offset);
                     }
                 }
@@ -301,11 +301,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<DstData,
-                              DstDataPerWrite,
-                              AddressSpace::vgpr,
-                              DstAddressSpace,
-                              DstInMemOp>(
+                    transfer_data<DstData,
+                                  DstDataPerWrite,
+                                  AddressSpace::vgpr,
+                                  DstAddressSpace,
+                                  DstInMemOp>(
                         p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
                 }
             }
@@ -401,11 +401,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<SrcData,
-                              SrcDataPerRead,
-                              SrcAddressSpace,
-                              AddressSpace::vgpr,
-                              InMemoryDataOperation::none>(
+                    transfer_data<SrcData,
+                                  SrcDataPerRead,
+                                  SrcAddressSpace,
+                                  AddressSpace::vgpr,
+                                  InMemoryDataOperation::none>(
                         p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
                 }
             }
@@ -446,14 +446,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     // has the valid/invalid mapping situation
                     if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
                     {
-                        move_data<DstData,
-                                  DstDataPerWrite,
-                                  AddressSpace::vgpr,
-                                  DstAddressSpace,
-                                  DstInMemOp>(p_dst_long_vector,
-                                              buffer_offset,
-                                              p_dst,
-                                              dst_nonlinear_coord.GetOffset() + dst_linear_offset);
+                        transfer_data<DstData,
+                                      DstDataPerWrite,
+                                      AddressSpace::vgpr,
+                                      DstAddressSpace,
+                                      DstInMemOp>(p_dst_long_vector,
+                                                  buffer_offset,
+                                                  p_dst,
+                                                  dst_nonlinear_coord.GetOffset() +
+                                                      dst_linear_offset);
                     }
                 }
             });
...
@@ -8,19 +8,12 @@ namespace ck {

 // outer-product: c[i,j] += inner_product(a[i], b[j])
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
 {
-// disable inline asm due to the compiler issue: SWDEV-202749
-///\to-do: enable the inline asm after the compiler fix
-#if CK_WORKAROUND_SWDEV_202749
-    c0 += a * b0;
-    c1 += a * b1;
-#else
     asm volatile("\n \
             v_mac_f32 %0, %2, %3 \n \
             v_mac_f32 %1, %2, %4 \n \
             "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
-#endif
 }

 // outer-product: c[i,j] += inner_product(a[i], b[j])
...
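For reference, the removed `CK_WORKAROUND_SWDEV_202749` branch computed the same result in plain C++; the two `v_mac_f32` instructions are exactly this pair of multiply-accumulates:

```cpp
// scalar equivalent of the re-enabled inline asm: c[j] += a * b[j]
void outer_product_1x2_reference(float a, float b0, float b1, float& c0, float& c1)
{
    c0 += a * b0; // v_mac_f32 %0, %2, %3
    c1 += a * b1; // v_mac_f32 %1, %2, %4
}
```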
@@ -43,6 +43,10 @@
 #define CK_USE_AMD_XDLOPS_INLINE_ASM 0
 #endif

+#ifndef CK_USE_AMD_XDLOPS_EMULATE
+#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
+#endif
+
 // experimental implementation
 #define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
 #define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
@@ -51,9 +55,6 @@
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

-// workaround
-#define CK_WORKAROUND_SWDEV_202749 1
-
 namespace ck {

 enum AddressSpace
...
@@ -70,7 +70,7 @@ template <typename T,
           AddressSpace SrcAddressSpace,
           AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp>
-__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     static_assert(DstInMemOp == InMemoryDataOperation::none ||
                       DstInMemOp == InMemoryDataOperation::atomic_add,
...
@@ -38,7 +38,7 @@ template <typename T,
           AddressSpace SrcAddressSpace,
           AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp>
-__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     static_assert(DstInMemOp == InMemoryDataOperation::none ||
                       DstInMemOp == InMemoryDataOperation::atomic_add,
...
@@ -103,9 +103,9 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
     return x < y ? x : y;
 }

-// highest common factor
+// greatest common divisor, aka highest common factor
 template <typename T>
-__host__ __device__ constexpr T hcf(T x, T y)
+__host__ __device__ constexpr T gcd(T x, T y)
 {
     if(x == 0)
     {
@@ -124,30 +124,30 @@ __host__ __device__ constexpr T gcd(T x, T y)
     if(x > y)
     {
-        return hcf(x - y, y);
+        return gcd(x - y, y);
     }

-    return hcf(x, y - x);
+    return gcd(x, y - x);
 }

 template <index_t X, index_t Y>
-__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
+__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
 {
-    constexpr auto result = hcf(X, Y);
+    constexpr auto result = gcd(X, Y);

     return Number<result>{};
 }

 template <typename X, typename... Ys>
-__host__ __device__ constexpr auto hcf(X x, Ys... ys)
+__host__ __device__ constexpr auto gcd(X x, Ys... ys)
 {
-    return hcf(x, ys...);
+    return gcd(x, ys...);
 }

 // least common multiple
 template <typename T>
 __host__ __device__ constexpr T lcm(T x, T y)
 {
-    return (x * y) / hcf(x, y);
+    return (x * y) / gcd(x, y);
 }

 template <typename X, typename Y, typename... Zs>
...
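A quick standalone check of the renamed helpers, mirroring the subtraction-based recursion above (plain `int` in place of `index_t`):

```cpp
// gcd by repeated subtraction, as in math.hpp after the rename
constexpr int gcd(int x, int y)
{
    return x == 0 ? y
         : y == 0 ? x
         : x == y ? x
         : x > y  ? gcd(x - y, y)
                  : gcd(x, y - x);
}

constexpr int lcm(int x, int y) { return (x * y) / gcd(x, y); }

static_assert(gcd(12, 18) == 6, "gcd(12, 18) == 6");
static_assert(lcm(4, 6) == 12, "lcm(4, 6) == 12");

int main() { return 0; }
```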
@@ -152,11 +152,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #endif

-    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
...
@@ -91,11 +91,11 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif

-    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
...
@@ -2,13 +2,18 @@
 #include <unistd.h>
 #include "device.hpp"
 #include "tensor.hpp"
-#include "gridwise_operation_wrapper.hpp"
 #include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"

 namespace launcher {

 using namespace ck;

+template <typename GridwiseOp, index_t GemmId, typename... Xs>
+__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs)
+{
+    GridwiseOp::template Run<GemmId>(xs...);
+}
+
 template <typename T,
           typename InDesc,
           typename WeiDesc,
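The wrapper above takes over from the generic `run_gridwise_operation` (whose include is dropped) because every `GemmId` must become its own kernel instantiation. A simplified host-side sketch of that static-dispatch idea (C++17; the names here are assumed, and CK itself uses `static_for`, as the launch loop later in this diff shows):

```cpp
#include <cstdio>
#include <utility>

template <int GemmId>
void launch_one_gemm()
{
    // the real launcher instantiates and launches Run<GemmId> here
    std::printf("would launch sub-GEMM %d\n", GemmId);
}

template <int... Ids>
void launch_all_gemms(std::integer_sequence<int, Ids...>)
{
    (launch_one_gemm<Ids>(), ...); // one kernel instantiation per compile-time id
}

int main()
{
    constexpr int num_gemms = 4; // e.g. GetNumberOfGemm() for stride 2x2, dilation 1x1
    launch_all_gemms(std::make_integer_sequence<int, num_gemms>{});
    return 0;
}
```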
@@ -119,11 +124,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif

-    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
@@ -154,69 +159,61 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     for(index_t i = 0; i < nrepeat; ++i)
     {
+        using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
+            GridSize,
+            BlockSize,
+            T,
+            T,
+            decltype(in_nchw_desc),
+            decltype(wei_kcyx_desc),
+            decltype(out_nkhw_desc),
+            ConvStrides,
+            ConvDilations,
+            InLeftPads,
+            InRightPads,
+            GemmMPerBlock,
+            GemmNPerBlock,
+            GemmKPerBlock,
+            GemmMPerThreadSubC,
+            GemmNPerThreadSubC,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmKPerThreadLoop,
+            GemmThreadGemmDataPerReadM,
+            GemmThreadGemmDataPerReadN,
+            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+            GemmABlockCopySrcDataPerRead_GemmM,
+            GemmABlockCopyDstDataPerWrite_GemmM,
+            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+            GemmBBlockCopySrcDataPerRead_GemmN,
+            GemmBBlockCopyDstDataPerWrite_GemmN,
+            GemmCThreadCopyDstDataPerWrite_GemmN1>;
+
         KernelTimer timer;
         timer.Start();

-        static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
-            static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
-                constexpr index_t ytilda = decltype(ytilda_){};
-                constexpr index_t xtilda = decltype(xtilda_){};
-
-                constexpr auto gridwise_conv =
-                    GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
-                        GridSize,
-                        BlockSize,
-                        T,
-                        T,
-                        decltype(in_nchw_desc),
-                        decltype(wei_kcyx_desc),
-                        decltype(out_nkhw_desc),
-                        ConvStrides,
-                        ConvDilations,
-                        InLeftPads,
-                        InRightPads,
-                        ytilda,
-                        xtilda,
-                        GemmMPerBlock,
-                        GemmNPerBlock,
-                        GemmKPerBlock,
-                        GemmMPerThreadSubC,
-                        GemmNPerThreadSubC,
-                        GemmMLevel0Cluster,
-                        GemmNLevel0Cluster,
-                        GemmMLevel1Cluster,
-                        GemmNLevel1Cluster,
-                        GemmKPerThreadLoop,
-                        GemmThreadGemmDataPerReadM,
-                        GemmThreadGemmDataPerReadN,
-                        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
-                        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
-                        GemmABlockCopySrcDataPerRead_GemmM,
-                        GemmABlockCopyDstDataPerWrite_GemmM,
-                        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
-                        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
-                        GemmBBlockCopySrcDataPerRead_GemmN,
-                        GemmBBlockCopyDstDataPerWrite_GemmN,
-                        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
-
-                launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
-                                                              T* const __restrict__,
-                                                              const T* const __restrict__,
-                                                              const T* const __restrict__>,
-                                       dim3(GridSize),
-                                       dim3(BlockSize),
-                                       0,
-                                       0,
-                                       gridwise_conv,
-                                       static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-                                       static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-                                       static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-            });
-        });
+        static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
+            constexpr index_t gemm_id = decltype(gemm_id_){};
+
+            launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConv,
+                                                                      gemm_id,
+                                                                      T* const __restrict__,
+                                                                      const T* const __restrict__,
+                                                                      const T* const __restrict__>,
+                          dim3(GridSize),
+                          dim3(BlockSize),
+                          0,
+                          0,
+                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+        });

         timer.End();

         float time = timer.GetElapsedTime();

         printf("Elapsed time : %f ms, %f TFlop/s\n",
...
@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-#if 1
+#if 0
     // BlockSize = 256, EPerBlock = 8, each thread holds 64 data
     constexpr index_t BlockSize = 256;
@@ -127,7 +127,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]

     constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
+#elif 1
+    // BlockSize = 256, EPerBlock = 16, each thread holds 64 data
+    // for 1x1
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t BPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t EPerBlock = 16;
+
+    constexpr index_t GemmNRepeat = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<4, 1, 1, 2>;
+    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
+
+    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
+    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
+
+    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
+    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
+
+    using WeiBlockCopySubLengths_E_K     = Sequence<4, 2>;
+    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
+
+    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
+
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
 #elif 1
     // BlockSize = 64, each thread holds 64 data
     constexpr index_t BlockSize = 64;
...
@@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
+#elif 0
     // BlockSize = 256, GemmKPerBlock = 16
     constexpr index_t BlockSize = 256;
@@ -117,7 +117,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 0
     // BlockSize = 256, GemmKPerBlock = 8
-    // 1x1 filter, 8x8 image
+    // for 1x1 filter, vector-read-b = 4
     constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
@@ -149,7 +149,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 1
     // BlockSize = 256, GemmKPerBlock = 16
-    // 1x1 filter, 8x8 image
+    // for 1x1 filter, vector-read-b = 4
     constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
...
@@ -161,10 +161,10 @@ int main(int argc, char* argv[])
 #elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N  = 128;
-    constexpr index_t C  = 1024;
+    constexpr index_t C  = 128;
     constexpr index_t HI = 17;
     constexpr index_t WI = 17;
-    constexpr index_t K  = 1024;
+    constexpr index_t K  = 128;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 7;
@@ -246,28 +246,28 @@ int main(int argc, char* argv[])
 #endif
     }

-#if 0
+#if 1
     device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
 #elif 0
     device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
-#elif 1
+#elif 0
     device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
 #elif 0
     device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
 #elif 1
     device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
 #endif
     (in_nchw_desc,
      in_nchw_device,
      wei_kcyx_desc,
      wei_kcyx,
      out_nkhw_desc,
      out_nkhw,
      ConvStrides{},
      ConvDilations{},
      LeftPads{},
      RightPads{},
      nrepeat);

     if(do_verification)
     {
...
@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 0
+#if 1
     // 1x1
-    constexpr index_t N  = 256;
-    constexpr index_t C  = 1024;
-    constexpr index_t HI = 8;
-    constexpr index_t WI = 8;
-    constexpr index_t K  = 1024;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K  = 256;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 1;
...