Commit a3230a64 authored by Chao Liu's avatar Chao Liu
Browse files

cosmetic fix

parent bf948337
......@@ -11,8 +11,8 @@ template <typename T,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads>
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
......@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
using namespace ck;
......@@ -62,14 +62,14 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; // Gemm-K, Gemm-M
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; // Gemm-K, Gemm-M
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; // Gemm-M
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; // Gemm-M
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; // Gemm-K, Gemm-N
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; // Gemm-K, Gemm-N
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
......@@ -80,8 +80,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmM = C * Y * X;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......@@ -95,8 +95,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
......
......@@ -11,8 +11,8 @@ template <typename T,
typename OutDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads>
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
......@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
InLeftPads,
InRightPads,
std::size_t nrepeat)
{
using namespace ck;
......@@ -101,8 +101,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * Htilda * Wtilda;
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......@@ -116,8 +116,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
......
......@@ -11,8 +11,8 @@ template <class T,
class OutDesc,
class ConvStrides,
class ConvDilations,
class LeftPads,
class RightPads>
class InLeftPads,
class InRightPads>
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
......@@ -21,8 +21,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
......@@ -181,10 +181,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
#endif
constexpr index_t B = N * Ho * Wo;
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize =
((B + GemmNPerBlock - 1) / GemmNPerBlock) * ((K + GemmMPerBlock - 1) / GemmMPerBlock);
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......@@ -198,8 +199,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment