Unverified Commit ac9e01e2 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Clean-up the headers (#713)



* fix headers for gpu instances

* remove unused headers

---------
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
parent 76ec0089
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// Number of GEMMs = YTilde * XTilde
// GemmM = C
// GemmN = N * HTildeSlice * WTildeSlice
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t IYTildeValue,
index_t IXTildeValue,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<IYTildeValue>,
Number<IXTildeValue>,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
constexpr auto IYTilde = Number<IYTildeValue>{};
constexpr auto IXTilde = Number<IXTildeValue>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde);
const auto K1 = GemmK1;
const auto K0 = K / K1;
// weight tensor
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilde),
make_freeze_transform(IXTilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
#if 1
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// output tensor
// this add padding check
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
#if 1
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(IXTilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_pass_through_transform(C),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: out
// B: wei
// C: in
// Number of GEMMs = YTilde * XTilde
// GemmM = N * HTildeSlice * WTildeSlice
// GemmN = C
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename IYTilde,
typename IXTilde,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
IYTilde i_ytilde,
IXTilde i_xtilde,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto K1 = GemmK1;
const auto K0 = K / K1;
// A: output tensor
// this add padding check
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
#if 1
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// B: weight tensor
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
#if 1
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
// A: out
// B: wei
// C: in
// Number of GEMMs = 1
// GemmM = N * Ho * Wo
// GemmN = C
// GemmK = K
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto K1 = GemmK1;
const auto K0 = K / K1;
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value,
typename GemmKBatchType,
typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>,
GemmKBatchType GemmKBatch,
GemmKPadType GemmKPad)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = C * Y * X;
const auto GemmKTotal = N * Ho * Wo;
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmktotal_gemmn_grid_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = C * Y * X;
const auto GemmK = N * Ho * Wo;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(out_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value,
typename GemmKBatchType,
typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>,
GemmKBatchType GemmKBatch,
GemmKPadType GemmKPad)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = Y * X * C;
const auto GemmN = K;
const auto GemmKTotal = N * Ho * Wo;
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: output tensor
const auto out_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = Y * X * C;
const auto GemmN = K;
const auto GemmK = N * Ho * Wo;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: output tensor
const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: out
// B: in
// C: wei
// GemmM = K
// GemmN = Y * X * C
// GemmKTotal = N * Ho * Wo
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value,
typename GemmKBatchType,
typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>,
GemmKBatchType GemmKBatch,
GemmKPadType GemmKPad)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = Y * X * C;
const auto GemmKTotal = N * Ho * Wo;
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
}
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = N * Ho * Wo;
const auto GemmN = K;
const auto GemmK = Y * X * C;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: weight tensor
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM0 = 1
// GemmM1 = K
// GemmN0 = N0
// GemmN1 = (N / N0) * Ho * Wo
// GemmK0 = (C / C0) * Y * X
// GemmK1 = C0
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename N0Type,
typename C0Type>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const N0Type& N0,
const C0Type& C0)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto N1 = N / N0;
const auto C1 = C / C0;
// weight tensor
const auto wei_gk0_gm0_gm1_gk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(C0, C1)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor(
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
make_pass_through_transform(N0),
make_merge_transform(make_tuple(N1, Ho, Wo)),
make_pass_through_transform(C0)),
make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_n_k_howo_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo));
const auto out_n0_n1_1_k_howo_grid_desc =
transform_tensor_descriptor(out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(N0),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return make_tuple(
wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatDsPointer,
typename FloatE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_bias_e_permute(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// input : A[M, K], or A[K, N]
// input : B[K, N], or A[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <typename ALayout,
typename BLayout,
typename CDELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmBiasEPermute_Xdl;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static constexpr index_t NumDTensor = 1;
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc)
{
index_t M0 = d_e_grid_desc.M0_;
index_t M1 = d_e_grid_desc.M1_;
index_t M2 = d_e_grid_desc.M2_;
index_t N0 = d_e_grid_desc.N0_;
index_t N1 = d_e_grid_desc.N1_;
index_t stride_M0 = d_e_grid_desc.stride_M0_;
index_t stride_M1 = d_e_grid_desc.stride_M1_;
index_t stride_M2 = d_e_grid_desc.stride_M2_;
index_t stride_N0 = d_e_grid_desc.stride_N0_;
index_t stride_N1 = d_e_grid_desc.stride_N1_;
const auto e_grid_desc_mraw_nraw = [&]() {
const auto e_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
make_tuple(M0, M1, M2, N0, N1),
make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1));
return transform_tensor_descriptor(
e_grid_desc_m0_m1_m2_n0_n1,
make_tuple(make_merge_transform(make_tuple(M0, M1, M2)),
make_merge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{}));
using DsGridDesc_M_N = Tuple<EGridDesc_M_N>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
const void* p_d_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
// populate pointer, desc for Ds
// D pointer
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);
// D desc
ds_grid_desc_m_n_(I0) = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_[I0]);
}
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_bias_e_permute<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx940"))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_d,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_d,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_d,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_d,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmBiasEPermute_Xdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseContraction,
typename FloatAB,
typename FloatC,
typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
typename CGridBlockCluster_BlockId_To_GM10_GN10,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
{
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t GK0PerBlock,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs,
typename BM10BN10ThreadClusterBN10Xs,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to known at compile-time
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// lds max alignment
// TODO: part of them should be moved into blockwise-gemm
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = GK1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
"wrong! GM0 and GN0 need to be known at compile-time");
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (
(GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr index_t GM11 = GM1PerBlockGM11;
constexpr index_t GN11 = GN1PerBlockGN11;
const index_t GM10 = GM1 / GM11;
const index_t GN10 = GN1 / GN11;
const index_t grid_size = GM10 * GN10;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
{
const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
{
const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
{
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
a_grid_desc_gk0_gm0_gm1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GK1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
{
const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
b_grid_desc_gk0_gn0_gn1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11)),
make_pass_through_transform(GK1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
}
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
constexpr auto BM = GM0 * GM11;
constexpr auto BN = GN0 * GN11;
constexpr auto BM1 =
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
BM1PerThreadBM11>{};
constexpr auto BN1 =
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
BN1PerThreadBN11>{};
constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
c_grid_desc_gm0_gm1_gn0_gn1,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)),
make_pass_through_transform(GN10),
make_merge_transform(make_tuple(GN0, GN11))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
make_pass_through_transform(GN10),
make_unmerge_transform(make_tuple(BN0, BN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
}
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_grid_block_cluster_blockid_to_gm10_gn10;
}
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
// lds max alignment
// TODO: part of them should be moved into blockwise-gemm
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = GK1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, igm10, 0, 0),
a_block_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, ign10, 0, 0),
b_block_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_gk0_bm_gk1),
decltype(b_block_desc_gk0_bn_gk1),
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
BM1PerThreadBM11,
BN1PerThreadBN11>{};
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t gk0_block_on_grid = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
gk0_block_on_grid += 2 * GK0PerBlock;
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
Sequence<1,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
1,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_multi_index(igm10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
ign10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf,
CGridStepHacks{});
}
}
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AKM0M1GridDesc,
typename BKN0N1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AKM0M1GridDesc a_k_m0_m1_grid_desc,
const BKN0N1GridDesc b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
index_t MPerBlockM1,
index_t NPerBlockN1,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
Number<BBlockTransferDstScalarPerVector_N1>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K == b_k_n_grid_desc.GetLength(I0)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
{
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{
const auto K = a_k_m_grid_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1;
const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
a_k_m_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return a_k_m0_m1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
{
const auto K = b_k_n_grid_desc.GetLength(I0);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1;
const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
b_k_n_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return b_k_n0_n1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto N11 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto cblockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
Number<BBlockTransferDstScalarPerVector_N1>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m0_m1_grid_desc),
decltype(a_k_m0_m1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k_m0_m1_grid_desc,
make_multi_index(0, im0, 0),
a_k_m0_m1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n0_n1_grid_desc),
decltype(b_k_n0_n1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k_n0_n1_grid_desc,
make_multi_index(0, in0, 0),
b_k_n0_n1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
AGridMoveSliceWindowStepHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridStepHacks{});
}
}
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_V2_HPP
#define CK_GRIDWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "blockwise_gemm_dlops_v3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalStepHacks,
typename BGlobalStepHacks,
typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto E = EPerBlock * 3 * 3;
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
constexpr auto E = EPerBlock * 3 * 3;
// const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [M, N]
#if 0
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: this force result into SGPR
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#endif
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_desc,
make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
1,
true>(
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf;
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
BGlobalMoveSliceWindowStepHacks{};
// double regsiter buffer for b
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
}
__syncthreads();
if constexpr(HasMainKBlockLoop)
{
index_t e_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
e_block_data_begin += 2 * EPerBlock;
} while(e_block_data_begin < E - 2 * EPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_k_n_ho_wo_global_desc,
c_global_buf,
c_k_n_ho_wo_global_tensor_step_hacks);
}
}
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
const auto b_e_n_ho_wo_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
const auto c_k_n_ho_wo_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_V3_HPP
#define CK_GRIDWISE_GEMM_V3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
#include "blockwise_gemm_dlops_v3.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::ConvBiasActiv(p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum, ActivType>{});
}
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v3_resize_add(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_d_grid,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid,
p_b_grid,
p_bias_grid,
p_d_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum, ActivType>{});
}
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v3_maxpool(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_d_grid,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::ConvBiasActivMaxpool(p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_d_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum, ActivType>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
typename DGridDesc_K_N_Hx_Wx,
index_t E1_,
index_t E2_,
index_t K2_,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t E1PerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_E2,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalStepHacks,
typename BGlobalStepHacks,
typename CGlobalStepHacks,
typename DGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto E1 = Number<E1_>{};
static constexpr auto E2 = Number<E2_>{};
static constexpr auto K2 = Number<K2_>{};
static constexpr auto NPerBlock = I1;
static constexpr FloatAcc alpha = 0.3;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = math::integer_least_multiple(
a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB);
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto H0 = Ho / HoPerBlock;
const auto W0 = Wo / WoPerBlock;
const index_t grid_size = K0 * N0 * H0 * W0;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0)
{
const bool has_main_e0_block_loop = E0 > 1;
return has_main_e0_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop()
{
const bool has_main_e1_block_loop = ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1;
return has_main_e1_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop()
{
const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0;
return has_double_tail_e1_block_loop;
}
__host__ __device__ static constexpr auto
MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc)
{
const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0);
const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor(
a_e0_e1_k_e2_grid_desc,
make_tuple(make_pass_through_transform(E0),
make_pass_through_transform(E1),
make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_e0_e1_k0_k1_e2_grid_desc;
}
__host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(
const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc)
{
const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0);
// const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1);
const auto N = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2);
const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3);
const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4);
// const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5);
const auto H2 = Number<HoPerThread>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Ho / (H1 * H2);
const auto W2 = Number<WoPerThread>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wo / (W1 * W2);
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc,
make_tuple(make_pass_through_transform(E0),
make_pass_through_transform(E1),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2)),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3, 4, 5>{},
Sequence<6, 7, 8>{},
Sequence<9>{}));
return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto H2 = Number<HoPerThread>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Ho / (H1 * H2);
const auto W2 = Number<WoPerThread>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wo / (W1 * W2);
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor(
c_k_n_ho_wo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
}
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0);
const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1);
const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
const auto H2 = Number<HoPerThread / 2>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Number<Hx / (H1 * H2)>{};
const auto W2 = Number<WoPerThread / 2>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Number<Wx / (W1 * W2)>{};
#else
const auto H2 = HoPerThread / 2;
const auto H1 = HoPerBlock / HoPerThread;
const auto H0 = Hx / (H1 * H2);
const auto W2 = WoPerThread / 2;
const auto W1 = WoPerBlock / WoPerThread;
const auto W0 = Wx / (W1 * W2);
#endif
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
d_k_n_hx_wx_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
}
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0);
const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1);
const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto H2 = Number<HoPerThread * 2>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto W2 = Number<WoPerThread * 2>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
const auto H0 = Number<Hx / (H1 * H2)>{};
const auto W0 = Number<Wx / (W1 * W2)>{};
#else
const auto H0 = Hx / (H1 * H2);
const auto W0 = Wx / (W1 * W2);
#endif
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
d_k_n_hx_wx_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
const auto K0 = Number<K / KPerBlock>{};
const auto N0 = Number<N / NPerBlock>{};
const auto H0 = Number<Ho / HoPerBlock>{};
const auto W0 = Number<Wo / WoPerBlock>{};
#else
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto H0 = Ho / HoPerBlock;
const auto W0 = Wo / WoPerBlock;
#endif
const auto cblockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
return cblockid_to_k_n_ho_wo_block_cluster_adaptor;
}
// using AGridDesc_E0_E1_K0_K1_E2 =
// decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
// using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
// decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
// using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
// decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
// using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
// decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
template <typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>
__host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
{
const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0);
const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1);
return make_naive_tensor_descriptor_packed(make_tuple(K0, K1));
}
__host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor()
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{}));
return c_k1_n_h2_w2_thread_gemm_desc;
}
// using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor());
__host__ __device__ static constexpr auto GetBlockWiseGemm()
{
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
constexpr auto b_e1_n_h_w_e2_block_gemm_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<E1PerBlock>{},
I1,
Number<HoPerBlock>{},
Number<WoPerBlock>{},
Number<E2>{}));
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e1_k1_e2_block_gemm_desc),
decltype(b_e1_n_h_w_e2_block_gemm_desc),
decltype(c_k1_n_h2_w2_thread_gemm_desc),
EPerThread,
K2>{};
return blockwise_gemm;
}
__device__ static constexpr auto GetCThreadIndex()
{
auto blockwise_gemm = GetBlockWiseGemm();
auto c_thread_mtx_index =
blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
return c_thread_mtx_index;
};
__device__ static constexpr auto GetCBlockIndex(
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor)
{
const auto c_k_n_h_w_block_cluster_idx =
cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
return c_k_n_h_w_block_cluster_idx;
}
template <typename BiasGlobalBuff,
typename CThreadBuff,
typename CBlockIndex,
typename CThreadIndex,
typename BiasGridDesc_K0_K1,
typename CThreadDesc_K1_N_H2_W2>
__device__ static void BiasOp(BiasGlobalBuff& bias_global_buf,
CThreadBuff& c_thread_buf,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc,
const CThreadDesc_K1_N_H2_W2&)
{
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const auto k_thread_id = c_thread_idx[I0];
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
constexpr auto bias_k0_k1_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
bias_k0_k1_thread_desc.GetElementSpaceSize(),
true>
bias_thread_buf;
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
auto bias_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatC,
FloatC,
decltype(bias_k0_k1_grid_desc),
decltype(bias_k0_k1_thread_desc),
Sequence<I1, Number<KPerThread>{}>,
Sequence<0, 1>,
1,
CThreadTransferDstScalarPerVector,
false,
true>(
bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global));
constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple(
make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{}));
bias_threadwise_transfer.Run(bias_k0_k1_grid_desc,
bias_global_buf,
bias_k0_k1_thread_desc,
make_tuple(I0, I0),
bias_thread_buf,
bias_k0_k1_global_tensor_step_hacks);
static_for<0, KPerThread, 1>{}([&](auto ki) {
static_for<0, HoPerThread, 1>{}([&](auto hi) {
static_for<0, WoPerThread, 1>{}([&](auto wi) {
constexpr index_t c_offset =
c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi));
c_thread_buf(Number<c_offset>{}) =
c_thread_buf[Number<c_offset>{}] + bias_thread_buf[ki];
});
});
});
}
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum activ_type_>
__device__ static void Activation(CThreadBuff& c_thread_buf,
const CThreadDesc_K1_N_H2_W2&,
integral_constant<ActivTypeEnum, activ_type_>)
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type_ == 1)
{
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
}
else if constexpr(activ_type_ == 2)
{
FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
asm volatile("\n \
v_rcp_f32 %0, %1 \n"
: "=v"(x)
: "0"(x));
c_thread_buf(i) = x;
}
});
}
template <typename CThreadBuff,
typename CGlobalBuff,
typename CBlockIndex,
typename CThreadIndex,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>
__device__ static void
WriteOut(const CThreadBuff& c_thread_buf,
CGlobalBuff& c_global_buf,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
{
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);
const auto k_thread_id = c_thread_idx[I0];
const auto ho_thread_id = c_thread_idx[I2];
const auto wo_thread_id = c_thread_idx[I3];
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
// tensor
constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{};
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{}));
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc),
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThread, I1, I1, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_global_buf,
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks);
}
template <typename CThreadBuff,
typename DGlobalBuff,
typename CBlockIndex,
typename CThreadIndex,
typename CThreadDesc_K1_N_H2_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>
__device__ static void
MaxPool(const CThreadBuff& c_thread_buf,
DGlobalBuff& d_global_buf,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const CThreadDesc_K1_N_H2_W2&,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc)
{
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);
const auto k_thread_id = c_thread_idx[I0];
const auto ho_thread_id = c_thread_idx[I2];
const auto wo_thread_id = c_thread_idx[I3];
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0, "");
constexpr auto HoPerThread_2 = HoPerThread / 2;
constexpr auto WoPerThread_2 = WoPerThread / 2;
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThread_2>{},
I1,
I1,
Number<WoPerThread_2>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
d_thread_buf;
static_for<0, KPerThread, 1>{}([&](auto ki) {
static_for<0, HoPerThread_2, 1>{}([&](auto hi) {
static_for<0, WoPerThread_2, 1>{}([&](auto wi) {
constexpr index_t d_offset =
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi));
constexpr index_t c_offset_0 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2, wi * 2));
constexpr index_t c_offset_1 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2, wi * 2 + 1));
constexpr index_t c_offset_2 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2 + 1, wi * 2));
constexpr index_t c_offset_3 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1));
d_thread_buf(Number<d_offset>{}) = c_thread_buf[Number<c_offset_0>{}];
d_thread_buf(Number<d_offset>{}) =
fmaxf(c_thread_buf[Number<c_offset_1>{}], d_thread_buf(Number<d_offset>{}));
d_thread_buf(Number<d_offset>{}) =
fmaxf(c_thread_buf[Number<c_offset_2>{}], d_thread_buf(Number<d_offset>{}));
d_thread_buf(Number<d_offset>{}) =
fmax(c_thread_buf[Number<c_offset_3>{}], d_thread_buf(Number<d_offset>{}));
});
});
});
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThread_2, I1, I1, WoPerThread_2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum::Set,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
d_global_buf,
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
}
template <typename CThreadBuff,
typename DGlobalBuff,
typename CBlockIndex,
typename CThreadIndex,
typename CThreadDesc_K1_N_H2_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>
__device__ static void
ResizeAdd(const CThreadBuff& c_thread_buf,
DGlobalBuff& d_global_buf,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const CThreadDesc_K1_N_H2_W2&,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc)
{
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);
const auto k_thread_id = c_thread_idx[I0];
const auto ho_thread_id = c_thread_idx[I2];
const auto wo_thread_id = c_thread_idx[I3];
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThreadx2>{},
I1,
I1,
Number<WoPerThreadx2>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
d_thread_buf;
static_for<0, KPerThread, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
d_thread_buf(Number<d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) =
c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}];
});
});
});
// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum::Add,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
d_global_buf,
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
}
template <typename AGlobalBuff,
typename BGlobalBuff,
typename CThreadBuff,
typename CBlockIndex,
typename CThreadIndex,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CThreadDesc_K1_N_H2_W2,
bool HasMainE0BlockLoop>
__device__ static void
GemmOp(const AGlobalBuff& a_global_buf,
const BGlobalBuff& b_global_buf,
CThreadBuff& c_thread_buf,
FloatAB* __restrict__ p_shared_block,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CThreadDesc_K1_N_H2_W2&,
integral_constant<bool, HasMainE0BlockLoop>)
{
constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop();
constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop();
// const auto c_k_n_h_w_block_cluster_idx =
// GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
// cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex(
// make_multi_index(get_block_1d_id()));
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
constexpr auto b_e1_n_h_w_e2_block_gemm_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<E1PerBlock>{},
I1,
Number<HoPerBlock>{},
Number<WoPerBlock>{},
Number<E2>{}));
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e1_k1_e2_block_gemm_desc),
decltype(b_e1_n_h_w_e2_block_gemm_desc),
decltype(c_k1_n_h2_w2_thread_gemm_desc),
EPerThread,
K2>{};
// blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
const auto ho_thread_id = c_thread_idx[I2];
const auto wo_thread_id = c_thread_idx[I3];
constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<I1>{}, Number<E1>{}, I1, Number<KPerBlock>{}, Number<E2>{}),
max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<I1, E1, I1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e0_e1_k0_k1_e2_grid_desc),
decltype(a_e0_e1_k0_k1_e2_block_copy_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorDim,
4,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_E2,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
false>(a_e0_e1_k0_k1_e2_grid_desc,
make_multi_index(0, 0, k_block_work_id, 0, 0),
a_e0_e1_k0_k1_e2_block_copy_desc,
make_multi_index(0, 0, 0, 0, 0));
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0);
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<E1PerBlock>{},
I1,
I1,
I1,
Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{},
Number<E2>{}));
auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc),
Sequence<I1, E1PerBlock, I1, I1, I1, HoPerThread, I1, I1, WoPerThread, E2>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
make_multi_index(0,
0,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0,
0));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());
//// register allocation for output
// StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAcc,
// c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
// true>
// c_thread_buf;
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k1_n_h2_w2_thread_gemm_desc),
Sequence<KPerThread, I1, HoPerThread, WoPerThread>>{}
.Run(c_k1_n_h2_w2_thread_gemm_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto b_thread_slice_copy_step =
make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};
// double regsiter buffer for b
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;
if constexpr(HasMainE0BlockLoop)
{
const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0);
index_t e0_block_data_begin = 0;
do
{
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
}
__syncthreads();
if constexpr(HasMainE1BlockLoop)
{
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
e1_block_data_begin += 2 * E1PerBlock;
} while(e1_block_data_begin < E1 - 2 * E1PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc,
a_block_slice_copy_step,
AGlobalMoveSliceWindowStepHacks{});
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
e0_block_data_begin += 1;
} while(e0_block_data_begin < E0);
}
else
{
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
}
__syncthreads();
if constexpr(HasMainE1BlockLoop)
{
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
e1_block_data_begin += 2 * E1PerBlock;
} while(e1_block_data_begin < E1 - 2 * E1PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
}
}
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__device__ static void
Conv(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global,
FloatC* __restrict__ p_d_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>)
{
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
const auto c_k_n_h_w_block_cluster_idx =
GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
c_thread_buf,
p_shared_block,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc,
integral_constant<bool, HasMainE0BlockLoop>{});
// Output
WriteOut(c_thread_buf,
c_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
}
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum ActivType>
__device__ static void ConvBiasActiv(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
const auto c_k_n_h_w_block_cluster_idx =
GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
c_thread_buf,
p_shared_block,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc,
integral_constant<bool, HasMainE0BlockLoop>{});
// Bias
BiasOp(bias_global_buf,
c_thread_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
bias_k0_k1_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc);
// Activ
Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
// Output
WriteOut(c_thread_buf,
c_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
}
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum ActivType>
__device__ static void ConvBiasActivMaxpool(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global,
FloatC* __restrict__ p_d_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
const auto c_k_n_h_w_block_cluster_idx =
GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
c_thread_buf,
p_shared_block,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc,
integral_constant<bool, HasMainE0BlockLoop>{});
// Bias
BiasOp(bias_global_buf,
c_thread_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
bias_k0_k1_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc);
// Activ
Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
// Output
WriteOut(c_thread_buf,
c_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
// MaxPool
MaxPool(c_thread_buf,
d_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k1_n_h2_w2_thread_gemm_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
}
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum ActivType>
__device__ static void ConvBiasActivResizeAdd(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_d_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
const auto c_k_n_h_w_block_cluster_idx =
GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
c_thread_buf,
p_shared_block,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc,
integral_constant<bool, HasMainE0BlockLoop>{});
// Bias
BiasOp(bias_global_buf,
c_thread_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
bias_k0_k1_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc);
// Activ
Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
// Resize_Add
ResizeAdd(c_thread_buf,
d_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k1_n_h2_w2_thread_gemm_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
}
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "static_tensor.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename Dst0Desc,
typename Dst1Desc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct ThreadwiseTensorSliceTransfer_v3r3
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r3(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Dst0Desc& dst0_desc,
const Dst1Desc& dst1_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)),
dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)),
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc,
const Dst0Desc& dst0_desc,
const Dst1Desc& dst1_desc,
const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx);
dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
src_vector_container.template AsType<SrcData>()(i) =
src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_.template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch()
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
static_assert(SrcVectorDim != DstVectorDim, "wrong");
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
// TODO type_convert is not used yet!!!!!
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return src_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
// TODO type_convert is not used yet!!!!!
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
});
}
#endif
}
template <typename DstBuffer, typename Dst0Buffer, typename Dst1Buffer>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf,
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf)
{
// if there is transpose, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch();
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make forward steps: dst0
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst0_desc, forward_step_idx);
},
Number<nDim>{});
// make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst1_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
// make backward steps: dst0
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst0_desc, backward_step_idx);
},
Number<nDim>{});
// make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst1_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
// apply DstElementwiseOperation on dst_vector_container
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
dst_vector_container.template AsType<DstData>()(i) =
dst_element_op_(dst_vector_container.template AsType<DstData>()[i]);
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
// TODO: BUG: should start at 1
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step;
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Dst0Desc dst0_desc,
const Dst1Desc dst1_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step);
move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>
src_thread_scratch_;
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(dst_thread_scratch_desc_),
true>
dst_thread_scratch_;
SrcCoord src_coord_;
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP
#include "data_type.hpp"
namespace ck {
__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment