Commit a3230a64 authored by Chao Liu

cosmetic fix

parent bf948337
@@ -11,8 +11,8 @@ template <typename T,
           typename OutDesc,
           typename ConvStrides,
           typename ConvDilations,
-          typename LeftPads,
-          typename RightPads>
+          typename InLeftPads,
+          typename InRightPads>
 void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
         Tensor<T>& in_nchw,
         WeiDesc wei_kcyx_desc,
@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
         const Tensor<T>& out_nkhw,
         ConvStrides,
         ConvDilations,
-        LeftPads,
-        RightPads,
+        InLeftPads,
+        InRightPads,
         std::size_t nrepeat)
 {
     using namespace ck;
@@ -62,14 +62,14 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;

-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;    // Gemm-K, Gemm-M
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; // Gemm-K, Gemm-M
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;  // Gemm-M
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; // Gemm-M
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;     // Gemm-K, Gemm-N
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; // Gemm-K, Gemm-N
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

     constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
@@ -80,8 +80,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmM = C * Y * X;
     constexpr index_t GemmN = N * Ho * Wo;

-    constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
-                                 ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
+    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
+                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -95,8 +95,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
                 decltype(out_nkhw_desc),
                 ConvStrides,
                 ConvDilations,
-                LeftPads,
-                RightPads,
+                InLeftPads,
+                InRightPads,
                 GemmMPerBlock,
                 GemmNPerBlock,
                 GemmKPerBlock,
...
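Besides the LeftPads/RightPads to InLeftPads/InRightPads rename, the one substantive-looking change in this file (and in the v2r1 file below) is replacing the hand-written round-up division with math::integer_divide_ceil when computing GridSize. A minimal standalone sketch of why the two forms agree, assuming the ck helper performs ordinary round-up division for positive integers; integer_divide_ceil below is a local stand-in, not the library's definition, and the problem sizes are made up:

#include <cstdio>

using index_t = int;

// Stand-in for ck::math::integer_divide_ceil (assumed semantics: round-up
// division for positive integers).
constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

int main()
{
    // Hypothetical sizes: GemmM = C*Y*X, GemmN = N*Ho*Wo.
    constexpr index_t GemmM = 32 * 3 * 3;    // 288
    constexpr index_t GemmN = 128 * 28 * 28; // 100352
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;

    constexpr index_t GridSize = integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 integer_divide_ceil(GemmN, GemmNPerBlock);

    // Same value as the previous inline expression.
    static_assert(GridSize == ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
                                  ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock),
                  "helper matches the old formula");

    printf("GridSize = %d\n", GridSize); // 3 * 784 = 2352
    return 0;
}

Either form launches one workgroup per GemmMPerBlock x GemmNPerBlock tile of the GEMM output, so the change is purely cosmetic.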
@@ -11,8 +11,8 @@ template <typename T,
           typename OutDesc,
           typename ConvStrides,
           typename ConvDilations,
-          typename LeftPads,
-          typename RightPads>
+          typename InLeftPads,
+          typename InRightPads>
 void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
         Tensor<T>& in_nchw,
         WeiDesc wei_kcyx_desc,
@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
         const Tensor<T>& out_nkhw,
         ConvStrides,
         ConvDilations,
-        LeftPads,
-        RightPads,
+        InLeftPads,
+        InRightPads,
         std::size_t nrepeat)
 {
     using namespace ck;
@@ -101,8 +101,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmM = C * Ytilda * Xtilda;
     constexpr index_t GemmN = N * Htilda * Wtilda;

-    constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
-                                 ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
+    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
+                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -116,8 +116,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
                 decltype(out_nkhw_desc),
                 ConvStrides,
                 ConvDilations,
-                LeftPads,
-                RightPads,
+                InLeftPads,
+                InRightPads,
                 GemmMPerBlock,
                 GemmNPerBlock,
                 GemmKPerBlock,
...
@@ -11,8 +11,8 @@ template <class T,
           class OutDesc,
           class ConvStrides,
           class ConvDilations,
-          class LeftPads,
-          class RightPads>
+          class InLeftPads,
+          class InRightPads>
 void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
         const Tensor<T>& in_nchw,
         WeiDesc,
@@ -21,8 +21,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
         Tensor<T>& out_nkhw,
         ConvStrides,
         ConvDilations,
-        LeftPads,
-        RightPads,
+        InLeftPads,
+        InRightPads,
         ck::index_t nrepeat)
 {
     using namespace ck;
@@ -181,10 +181,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
 #endif

-    constexpr index_t B = N * Ho * Wo;
+    constexpr index_t GemmM = K;
+    constexpr index_t GemmN = N * Ho * Wo;

-    constexpr index_t GridSize =
-        ((B + GemmNPerBlock - 1) / GemmNPerBlock) * ((K + GemmMPerBlock - 1) / GemmMPerBlock);
+    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
+                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -198,8 +199,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
                 decltype(out_nkhw_desc),
                 ConvStrides,
                 ConvDilations,
-                LeftPads,
-                RightPads,
+                InLeftPads,
+                InRightPads,
                 GemmMPerBlock,
                 GemmNPerBlock,
                 GemmKPerBlock,
...
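In the v4r4 forward driver the grid was previously sized from B = N * Ho * Wo and K directly; the new code spells out the GEMM view (GemmM = K, GemmN = N * Ho * Wo) and uses the same ceil-division helper, which leaves the launch grid unchanged and only aligns the naming with the backward-data drivers above. A short sketch of that equivalence, again with a stand-in helper and hypothetical sizes:

#include <cstdio>

using index_t = int;

// Stand-in for ck::math::integer_divide_ceil (assumed round-up integer division).
constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

int main()
{
    // Hypothetical output-tensor sizes: K channels, N * Ho * Wo pixels.
    constexpr index_t K = 256, N = 64, Ho = 14, Wo = 14;
    constexpr index_t GemmMPerBlock = 128, GemmNPerBlock = 128;

    // Old formulation: grid computed from B = N*Ho*Wo and K directly.
    constexpr index_t B = N * Ho * Wo;
    constexpr index_t GridSizeOld =
        ((B + GemmNPerBlock - 1) / GemmNPerBlock) * ((K + GemmMPerBlock - 1) / GemmMPerBlock);

    // New formulation: the same two factors under GEMM names, using the helper.
    constexpr index_t GemmM = K;
    constexpr index_t GemmN = N * Ho * Wo;
    constexpr index_t GridSizeNew = integer_divide_ceil(GemmM, GemmMPerBlock) *
                                    integer_divide_ceil(GemmN, GemmNPerBlock);

    static_assert(GridSizeOld == GridSizeNew, "renaming does not change the launch grid");

    printf("GridSize old = %d, new = %d\n", GridSizeOld, GridSizeNew); // both 196
    return 0;
}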