Unverified Commit 38a90b6e authored by Chao Liu, committed by GitHub

Merge pull request #43 from ROCmSoftwarePlatform/develop

Merge develop into master
parents 88833bd9 c3018794
@@ -21,8 +21,8 @@ template <typename... Wei,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
-          index_t IYTildaValue,
-          index_t IXTildaValue,
+          typename IYTilda,
+          typename IXTilda,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
-    Number<IYTildaValue>,
-    Number<IXTildaValue>,
+    IYTilda i_ytilda,
+    IXTilda i_xtilda,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
@@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};
-    constexpr auto IYTilda = Number<IYTildaValue>{};
-    constexpr auto IXTilda = Number<IXTildaValue>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
@@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;

    // GemmK is different for each GEMM
-    const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
-    const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
+    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
+    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);

    const auto K1 = GemmK1;
    const auto K0 = K / K1;
@@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                   make_slice_transform(YDot, I0, YDotSlice),
                   make_slice_transform(XDot, I0, XDotSlice),
-                   make_freeze_transform(IYTilda),
-                   make_freeze_transform(IXTilda),
+                   make_freeze_transform(i_ytilda),
+                   make_freeze_transform(i_xtilda),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
@@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
        in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
-                   make_freeze_transform(IYTilda),
+                   make_freeze_transform(i_ytilda),
                   make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
-                   make_freeze_transform(IXTilda),
+                   make_freeze_transform(i_xtilda),
                   make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
@@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
                      in_gemmm_gemmn_grid_desc);
}
// A: out
// B: wei
// C: in
// Number of GEMMs = 1
// GemmM = N * Ho * Wo
// GemmN = C
// GemmK = K
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto K1 = GemmK1;
const auto K0 = K / K1;
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
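
// Illustrative shape example for the 1x1 case above (hypothetical sizes, not taken from the
// library): with N = 2, Ho = Wo = 16, K = 64, C = 32 and GemmK1Value = 4, the returned
// descriptors describe a single GEMM with
//   GemmM = N * Ho * Wo = 2 * 16 * 16 = 512,
//   GemmN = C           = 32,
//   GemmK = K           = 64, unmerged as K0 * K1 = 16 * 4,
// i.e. in_grad[n, ho * ConvStrideH, wo * ConvStrideW, c] = sum over k of
// out_grad[n, ho, wo, k] * wei[k, c], where the embed transforms pick out only the input
// pixels that the 1x1 filter touches at the given strides.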
} // namespace ck
#endif
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value,
typename GemmKBatchType,
typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>,
GemmKBatchType GemmKBatch,
GemmKPadType GemmKPad)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = C * Y * X;
const auto GemmKTotal = N * Ho * Wo;
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmktotal_gemmn_grid_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
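
// Worked example of the K split used above (hypothetical sizes, for illustration only): with
// N = 4, Ho = Wo = 28, GemmK1Value = 4 and GemmKBatch = 2, a caller would choose GemmKPad as a
// multiple of GemmKBatch * GemmK1 that is >= GemmKTotal:
//   GemmKTotal = N * Ho * Wo = 4 * 28 * 28 = 3136,
//   GemmKPad   = 3136 (already a multiple of 2 * 4 = 8),
//   GemmK0     = GemmKPad / (GemmKBatch * GemmK1) = 3136 / 8 = 392.
// Each of the GemmKBatch slices then reduces GemmK0 * GemmK1 elements of the padded K
// dimension into the same weight tile, which is what the "atomic" in the name refers to:
// partial results from different K batches are accumulated into the same C tile.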
} // namespace ck
#endif
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: out
// C: wei
// GemmM = Y * X * C
// GemmN = K
// GemmKTotal = N * Ho * Wo
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value,
typename GemmKBatchType,
typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>,
GemmKBatchType GemmKBatch,
GemmKPadType GemmKPad)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = Y * X * C;
const auto GemmN = K;
const auto GemmKTotal = N * Ho * Wo;
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: output tensor
const auto out_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: out
// C: wei
// GemmM = Y * X * C
// GemmN = K
// GemmK = N * Ho * Wo
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = Y * X * C;
const auto GemmN = K;
const auto GemmK = N * Ho * Wo;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: output tensor
const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
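
// Note on the variant above (a reading of the code, not a guarantee): here the reduction
// dimension GemmK = N * Ho * Wo is split directly as GemmK0 * GemmK1, with no K padding and no
// GemmKBatch dimension, so it appears to assume that N * Ho * Wo is divisible by GemmK1Value.
// Hypothetical example: N = 1, Ho = Wo = 14, GemmK1Value = 4 gives GemmK = 196 and
// GemmK0 = 196 / 4 = 49, whereas Ho = Wo = 15 (GemmK = 225) would not divide evenly; the
// K-padded "atomic" transforms above handle that case by right-padding GemmKTotal to GemmKPad.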
} // namespace ck
#endif
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: out
// B: in
// C: wei
// GemmM = K
// GemmN = Y * X * C
// GemmKTotal = N * Ho * Wo
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value,
typename GemmKBatchType,
typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>,
GemmKBatchType GemmKBatch,
GemmKPadType GemmKPad)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = Y * X * C;
const auto GemmKTotal = N * Ho * Wo;
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
// A: output tensor
const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmktotal_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
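
// The three descriptors returned above describe, per K batch, the GEMM (illustrative index
// form; padded input positions contribute zero):
//   wei_grad[k, y, x, c] += sum over (n, ho, wo) of
//       out_grad[n, ho, wo, k] *
//       in[n, ho * ConvStrideH + y * ConvDilationH - InLeftPadH,
//             wo * ConvStrideW + x * ConvDilationW - InLeftPadW, c],
// with GemmM = K, GemmN = Y * X * C, and the reduction dimension GemmKTotal = N * Ho * Wo
// right-padded to GemmKPad and unmerged into (GemmKBatch, GemmK0, GemmK1).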
} // namespace ck
#endif
@@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform(
    return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
}

-template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
+template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_right_pad_transform(
    const LowLength& low_length,
    const RightPadLength& right_pad,
......
@@ -10,6 +10,7 @@ namespace ck {
template <index_t BlockSize,
          typename FloatAB,
+          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerXDL,
@@ -29,14 +30,18 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t KPerBlock = K0;

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

+    StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
+        c_thread_buf_;
+
+    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();
@@ -162,7 +167,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+            make_tuple(make_pass_through_transform(Number<K0>{}),
                       make_unmerge_transform(
                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
                       make_pass_through_transform(Number<K1>{})),
@@ -174,7 +179,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+            make_tuple(make_pass_through_transform(Number<K0>{}),
                       make_unmerge_transform(
                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
                       make_pass_through_transform(Number<K1>{})),
@@ -195,48 +200,43 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

-        vector_type<FloatAB, K1> a_thread_vec;
-        vector_type<FloatAB, K1> b_thread_vec;
-
-        static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
-            // read A
-            a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
-                               make_tuple(k0, I0, I0, I0, I0),
-                               a_block_buf,
-                               a_thread_desc_,
-                               make_tuple(I0, I0, I0, I0, I0),
-                               a_thread_buf);
-
-            // read B
-            b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
-                               make_tuple(k0, I0, I0, I0, I0),
-                               b_block_buf,
-                               b_thread_desc_,
-                               make_tuple(I0, I0, I0, I0, I0),
-                               b_thread_buf);
-
-            using mfma_input_type = typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
-
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    static_for<0, K1, 1>{}([&](auto i) {
-                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
-                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
-                    });
-
-                    static_for<0, K1, 1>{}([&](auto i) {
-                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
-                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
-                    });
-
-                    constexpr index_t c_offset =
-                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-
-                    xdlops_gemm.template Run<c_offset>(
-                        a_thread_vec.template AsType<mfma_input_type>(),
-                        b_thread_vec.template AsType<mfma_input_type>(),
-                        c_thread_buf);
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
+            // read A
+            a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
+                               make_tuple(I0, m0, I0, I0, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, I0, I0, I0),
+                               a_thread_buf);
+
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
+                                   make_tuple(I0, n0, I0, I0, I0),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, I0, I0, I0),
+                                   b_thread_buf);
+
+                static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
+                    vector_type<FloatAB, K1> a_thread_vec;
+                    vector_type<FloatAB, K1> b_thread_vec;
+
+                    static_for<0, K1, 1>{}([&](auto i) {
+                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
+                            [Number<a_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
+                            [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
+                    });
+
+                    using mfma_input_type =
+                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0));
+
+                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                             b_thread_vec.template AsType<mfma_input_type>(),
+                                             c_thread_buf.GetVector(Number<c_offset>{}));
                });
            });
        });
@@ -244,35 +244,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    private:
    // A[K, M]
-    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));

    // B[K, N]
-    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));

-    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));
+    static constexpr auto c_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(a_k0_m0_m1_m2_k1_block_desc),
                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, MRepeat, 1, 1, K1>,
+                                                         Sequence<K0, 1, 1, 1, K1>,
                                                         Sequence<0, 1, 2, 3, 4>,
                                                         4,
                                                         K1,
-                                                         1>;
+                                                         K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(b_k0_n0_n1_n2_k1_block_desc),
                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, NRepeat, 1, 1, K1>,
+                                                         Sequence<K0, 1, 1, 1, K1>,
                                                         Sequence<0, 1, 2, 3, 4>,
                                                         4,
                                                         K1,
-                                                         1>;
+                                                         K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
......
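
// A short reading of the restructured Run() above (inferred from the new code, for
// illustration): the loops now iterate m0 over MRepeat, n0 over NRepeat, and only then k0 over
// K0 in steps of xdlops_gemm.K0PerXdlops, so each (m0, n0) pair accumulates into its own slot
// of c_thread_buf_. With c_thread_desc_ packed as (MRepeat, NRepeat), that slot's offset is
// simply c_offset = m0 * NRepeat + n0, and the per-slot accumulator is the
// vector_type<FloatAcc, 16> element handed to xdlops_gemm.Run via GetVector(Number<c_offset>{}).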
@@ -29,7 +29,7 @@ __global__ void
        FloatC* __restrict__ p_c_grid,
        const AK0MK1GridDesc a_k0_m_k1_grid_desc,
        const BK0NK1GridDesc b_k0_n_k1_grid_desc,
-        const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
+        const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
        const CBlockClusterAdaptor c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
@@ -132,7 +132,9 @@ template <index_t BlockSize,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
-          bool CAccessOrderMRepeatNRepeat>
+          bool CAccessOrderMRepeatNRepeat,
+          bool ABlockLdsExtraM,
+          bool BBlockLdsExtraN>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};
@@ -142,6 +144,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
+    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};
@@ -151,14 +154,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
-        // be careful of LDS alignment
-        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+        constexpr auto a_k0_m_k1_block_desc = [&]() {
+            if constexpr(ABlockLdsExtraM)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+            }
+        }();

        // B matrix in LDS memory, dst of blockwise copy
-        // be careful of LDS alignment
-        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+        constexpr auto b_k0_n_k1_block_desc = [&]() {
+            if constexpr(BBlockLdsExtraN)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+            }
+        }();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
@@ -170,29 +193,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
    }

+    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                  const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                  const CMNGridDesc& c_m_n_grid_desc)
+                  const CMNGridDesc& c_m_n_grid_desc,
+                  index_t M01,
+                  index_t N01)
    {
-        // TODO: turn on this
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

+        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
+                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
+                      "Invalid tuning param!");
+
        const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
        const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);

-        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
-                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
-                      "Invalid tuning param!");
+        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
+             K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
+             K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
+            return false;
+
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
+            return false;
+
+        // check M01, N01
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
+
+        const auto M0 = M / M1;
+        const auto N0 = N / N1;
+
+        if(!(M0 % M01 == 0 && N0 % N01 == 0))
+            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-        return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
-                K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
-                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
-                K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
-               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
+        return true;
    }
    __host__ __device__ static constexpr index_t
@@ -211,15 +250,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    {
        constexpr auto max_lds_align = K1;

-        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+        // A matrix in LDS memory, dst of blockwise copy
+        constexpr auto a_k0_m_k1_block_desc = [&]() {
+            if constexpr(ABlockLdsExtraM)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+            }
+        }();

-        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+        // B matrix in LDS memory, dst of blockwise copy
+        constexpr auto b_k0_n_k1_block_desc = [&]() {
+            if constexpr(BBlockLdsExtraN)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+            }
+        }();

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
+                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXDL,
@@ -231,8 +295,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
    }

+    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
-    MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);
@@ -243,23 +308,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        const auto M0 = M / M1;
        const auto N0 = N / N1;

-#if 1
-        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
-                                             make_tuple(Sequence<0, 1>{}),
-                                             make_tuple(Sequence<0>{}));
-#elif 1
-        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
-            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
-                                             make_tuple(Sequence<1, 0>{}),
-                                             make_tuple(Sequence<0>{}));
-#endif
+        const auto M00 = M0 / M01;
+        const auto N00 = N0 / N01;
+
+        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
+                           make_unmerge_transform(make_tuple(N00, N01))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
+
+        const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
+                make_tuple(Sequence<0, 1, 2, 3>{}),
+                make_tuple(Sequence<0>{}));
+
+        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
+                                  c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return c_blockid_to_m0_n0_block_cluster_adaptor;
    }

    using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
-    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
+    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));

    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
@@ -294,14 +367,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
-        // be careful of LDS alignment
-        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+        constexpr auto a_k0_m_k1_block_desc = [&]() {
+            if constexpr(ABlockLdsExtraM)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+            }
+        }();

        // B matrix in LDS memory, dst of blockwise copy
-        // be careful of LDS alignment
-        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+        constexpr auto b_k0_n_k1_block_desc = [&]() {
+            if constexpr(BBlockLdsExtraN)
+            {
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+            }
+            else
+            {
+                return make_naive_tensor_descriptor_aligned(
+                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+            }
+        }();

        // A matrix blockwise copy
        auto a_blockwise_copy =
@@ -363,9 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        // register
        // sanity check
-        const auto blockwise_gemm =
+        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
+                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
                                                                MPerXDL,
@@ -374,18 +468,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                                NRepeat,
                                                                K1>{};

-        constexpr auto c_mr_nr_blk_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
-
-        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
-            blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
-
-        constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
-
-        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                     vector_type<FloatAcc, CBlkSize>,
-                     c_mr_nr_blk_desc.GetElementSpaceSize(),
-                     true>
-            c_thread_buf;
+        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
@@ -460,9 +543,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
            blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();

+        constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
+        constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
+        constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
+        constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
        constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
        constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
        constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
+        constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
+
+        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
+            make_naive_tensor_descriptor_packed(make_tuple(
+                Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));

        // calculate origin of thread output tensor on global memory
        // blockwise GEMM c matrix starting index
@@ -477,224 +569,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};

+        const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
+                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                make_tuple(Sequence<0>{}));
+
+        const auto m_thread_data_on_grid_idx =
+            m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+                make_multi_index(m_thread_data_on_grid));
+
+        const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto n_thread_data_on_grid_idx =
+            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+                make_multi_index(n_thread_data_on_grid));
+
        auto c_thread_copy =
-            ThreadwiseTensorSliceTransfer_v1r3<FloatC,
+            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                               FloatC,
                                               decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                                               decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
-                                               Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
+                                               Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
                                               CThreadTransferSrcDstAccessOrder,
                                               CThreadTransferSrcDstVectorDim,
                                               CThreadTransferDstScalarPerVector,
                                               CGlobalMemoryDataOperation,
                                               1,
                                               true>{
                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                make_multi_index(0,
-                                 0,
-                                 0,
-                                 0,
-                                 m_thread_data_on_grid / (M3 * M4),
-                                 m_thread_data_on_grid % (M3 * M4) / M4,
-                                 m_thread_data_on_grid % M4,
-                                 n_thread_data_on_grid)};
+                make_multi_index(m_thread_data_on_grid_idx[I0],
+                                 n_thread_data_on_grid_idx[I0],
+                                 m_thread_data_on_grid_idx[I1],
+                                 n_thread_data_on_grid_idx[I1],
+                                 m_thread_data_on_grid_idx[I2],
+                                 m_thread_data_on_grid_idx[I3],
+                                 m_thread_data_on_grid_idx[I4],
+                                 n_thread_data_on_grid_idx[I2])};

+        c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
+                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                          c_thread_buf,
+                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                          c_grid_buf,
+                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);

        auto init_copy = [&](auto c_thread_idx_) {
            constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
            c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                              c_grid_buf,
                              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
return c_thread_idx_;
};
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_minus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
(MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
(MRepeat == 1 && NRepeat == 1),
"wrong");
if constexpr(MRepeat == 4 && NRepeat == 4)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
nrepeat_plus_copy(make_tuple(I2, I2));
nrepeat_plus_copy(make_tuple(I2, I3));
mrepeat_plus_copy(make_tuple(I3, I3));
nrepeat_minus_copy(make_tuple(I3, I2));
nrepeat_minus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
mrepeat_plus_copy(make_tuple(I2, I2));
mrepeat_plus_copy(make_tuple(I3, I2));
nrepeat_plus_copy(make_tuple(I3, I3));
mrepeat_minus_copy(make_tuple(I2, I3));
mrepeat_minus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
else if constexpr(MRepeat == 4 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
mrepeat_plus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 4)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
nrepeat_plus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 1)
{
init_copy(make_tuple(I0, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
}
else if constexpr(MRepeat == 1 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
nrepeat_plus_copy(make_tuple(I0, I1));
}
else if constexpr(MRepeat == 1 && NRepeat == 1)
{
init_copy(make_tuple(I0, I0));
}
} }
} }
}; // namespace ck }; // namespace ck
......
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast<const ABK0MK1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc));
const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast<const BBK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc));
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsExtraM,
bool BBlockLdsExtraN>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
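// Back-of-envelope sketch of the byte count above (assumed fp16-sized elements with
// KPerBlock = 4, MPerBlock = NPerBlock = 128, K1 = 8 and no LDS padding; the numbers are
// hypothetical and for illustration only): A and B each stage 4 * 128 * 8 = 4096 elements,
// so the block needs (4096 + 4096) * 2 bytes = 16 KiB of LDS.
static_assert((4 * 128 * 8 + 4 * 128 * 8) * 2 == 16384,
              "illustrative LDS budget: 16 KiB for a 128x128 tile with K0 = 4, K1 = 8");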
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
index_t M01,
index_t N01)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
return false;
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
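// Worked example of the divisibility checks above (assumed problem sizes, purely
// illustrative): with MPerBlock = 128, NPerBlock = 128, KPerBlock = 4, a problem of
// M = 1024, N = 512, K0 = 64 is accepted because every dimension is a whole multiple of
// its per-block tile length.
static_assert(1024 % 128 == 0 && 512 % 128 == 0 && 64 % 4 == 0,
              "illustrative problem sizes that pass the tiling constraints in CheckValidity");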
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
return grid_size;
}
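// Grid-size arithmetic example (assumed sizes, illustration only): M = 1024, N = 512,
// MPerBlock = NPerBlock = 128 and KBatch = 2 launch (1024/128) * (512/128) * 2 = 64
// workgroups, i.e. one workgroup per (k-batch, M-tile, N-tile) triple.
static_assert((1024 / 128) * (512 / 128) * 2 == 64,
              "illustrative grid size: 8 M-tiles * 4 N-tiles * 2 k-batches");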
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>;
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(KBatch),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
}
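// Plain-arithmetic sketch of the chained mapping above (the constants below are
// hypothetical and only illustrate the recombination step): after the flat block id is
// split into (kbatch, m00, n00, m01, n01) with n01 varying fastest, the tile index is
// recombined as m0 = m00 * M01 + m01 and n0 = n00 * N01 + n01, which walks block ids
// through an M01 x N01 super-tile of output tiles.
static_assert(1 * 4 + 3 == 7 && 2 * 2 + 1 == 5,
              "illustrative recombination with M01 = 4, N01 = 2: (m00,m01)=(1,3) -> m0 = 7, "
              "(n00,n01)=(2,1) -> n0 = 5");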
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t k_batch_id = block_work_idx[I0];
// HACK: this forces m/n_block_data_idx_on_grid into SGPRs
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto a_b_k0_m_k1_block_desc = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<KPerBlock>{} * Number<MPerBlock + 1>{} * K1,
Number<MPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto b_b_k0_n_k1_block_desc = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<KPerBlock>{} * Number<NPerBlock + 1>{} * K1,
Number<NPerBlock + 1>{} * K1,
K1,
I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
max_lds_align);
}
}();
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_b_k0_m_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// registers
// sanity check
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
// hack to control index calculation when iterating over the A and B matrices for threadwise copy
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when moving the slice window over the A and B
// matrices for threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}
// main body
index_t k_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
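// Loop-trip sketch for the pipeline above (assumed K0 = 64, KPerBlock = 4, purely
// illustrative): the do-while body runs K0/KPerBlock - 1 = 15 times, each pass reading the
// next K-slice from global memory while the current one is consumed by blockwise_gemm, and
// the tail block just above issues the 16th and final GEMM on the last tile already in LDS.
static_assert(64 / 4 - 1 == 15 && 64 / 4 == 16,
              "illustrative trip count: 15 pipelined iterations plus one tail GEMM");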
// output: register to global memory
{
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx =
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_idx =
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
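// Index-split sketch for the two adaptors above (assumed lengths M0 = 2, M1 = 2, M2 = 4,
// M3 = 1, M4 = 4; hypothetical values for illustration): the flat m coordinate is
// decomposed in row-major order with M4 varying fastest, so m = 37 splits into
// (m0, m1, m2, m3, m4) = (1, 0, 1, 0, 1).
static_assert(37 / (2 * 4 * 1 * 4) == 1 && 37 % 4 == 1,
              "illustrative decomposition of m = 37: m0 = 1 and m4 = 1");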
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2])};
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}
}
}; // struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
} // namespace ck
#endif
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_blockwise.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
namespace ck {
template <index_t BlockSize,
typename srcDataType,
typename dstDataType,
typename compType,
typename src2dDescType,
typename dst1dDescType,
ReduceTensorOp_t op,
NanPropagation_t nanPropaOpt,
ReduceTensorIndices_t reduceIndicesOpt,
bool isFirstCall,
bool isLastCall,
index_t GredAccessesPerThreadInBlock>
struct GridwiseReduction_xy_to_x_blockwise
{
using opReduce = typename reduce_binary_operator<compType, op>::opType;
using preUnaryOpType =
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
using posUnaryOpType =
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;
static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<GredAccessesPerThreadInBlock>{}, Number<BlockSize>{}));
using blockwise_reduce =
BlockwiseReduction_2d_block_buffer<decltype(buffer2dDesc), true, opReduce, nanPropaOpt>;
static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize();
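// Buffer sizing example (assumed GredAccessesPerThreadInBlock = 2, BlockSize = 256;
// hypothetical numbers for illustration): the 2-D LDS staging buffer described above then
// holds 2 * 256 = 512 elements per pass, one row of BlockSize elements per access.
static_assert(2 * 256 == 512,
              "illustrative BlockBufferSize = GredAccessesPerThreadInBlock * BlockSize");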
static constexpr auto I0 = Number<0>{};
template <int RunId>
__device__ static void Run(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global);
template <>
__device__ static void Run<1>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)ws_indices_global;
(void)indices_global;
// LDS
__shared__ compType p_in_block_buffer[BlockBufferSize];
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto in_block_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
accuValue_buf(I0) = zeroVal;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
const posUnaryOpType posUnaryOp(divider);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
constexpr auto in_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
using ThreadClusterLengths = Sequence<1, BlockSize>;
auto blockwise_src_load =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, BlockBufferSize>,
ThreadSliceLengths,
ThreadClusterLengths,
Sequence<0, 1>,
srcDataType,
compType,
src2dDescType,
decltype(in_block_desc),
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
1,
1,
1,
1,
false,
true>(src2dDesc,
make_multi_index(block_global_1d_id, 0),
in_block_desc,
make_multi_index(0, 0));
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;
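// The line above is an integer ceiling division. Worked example (assumed
// toReduceLength = 1000, BlockSize = 256; illustration only): it yields 4 chunks of
// BlockSize columns, the last only partially filled; the padded slots carry the
// reduction's zero value, so they do not change the result.
static_assert((1000 + 256 - 1) / 256 == 4, "illustrative ceil(1000 / 256) == 4");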
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
reducedBlocks += GredAccessesPerThreadInBlock)
{
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
blockwise_src_load.RunWrite(in_block_desc, in_block_buf);
__syncthreads();
// do element-wise pre-reduction operation
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf);
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
? GredAccessesPerThreadInBlock
: toReduceBlocks - reducedBlocks;
blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0));
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
}
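// Tail handling in the loop above, worked through with assumed numbers
// (toReduceBlocks = 10, GredAccessesPerThreadInBlock = 4; illustration only): the loop
// visits reducedBlocks = 0, 4, 8, and only the final pass clips BlocksInOneOp from 4 down
// to 10 - 8 = 2 so the reduction never touches blocks past the end.
static_assert((8 < 10 - 4 ? 4 : 10 - 8) == 2,
              "illustrative final pass reduces only the 2 remaining blocks");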
accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
// The first thread in the block stores the reduced result to the global location
// representing the block
if(thread_local_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
false>(dst1dDesc,
make_multi_index(block_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
false>(dst1dDesc,
make_multi_index(block_global_1d_id));
threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
}
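// Worked example of the alpha/beta update above (made-up numbers, illustration only): a
// reduced value of 6 with alpha = 0.5 and beta = 2 applied to a prior output of 1.5
// stores 0.5 * 6 + 2 * 1.5 = 6.
static_assert(0.5 * 6 + 2 * 1.5 == 6.0,
              "illustrative dst = alpha * reduced + beta * prior_dst");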
};
template <>
__device__ static void Run<2>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)ws_indices_global;
// LDS
__shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize];
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
indices_global, dst1dDesc.GetElementSpaceSize());
auto in_block_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
auto in_block_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(block_indices_buffer, BlockBufferSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
accuValue_buf(I0) = zeroVal;
accuIndex_buf(I0) = 0;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
constexpr auto in_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
using ThreadClusterLengths = Sequence<1, BlockSize>;
auto blockwise_src_load =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, BlockBufferSize>,
ThreadSliceLengths,
ThreadClusterLengths,
Sequence<0, 1>,
srcDataType,
compType,
src2dDescType,
decltype(in_block_desc),
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
1,
1,
1,
1,
false,
true>(src2dDesc,
make_multi_index(block_global_1d_id, 0),
in_block_desc,
make_multi_index(0, 0));
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;
int indexOffset = 0;
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
reducedBlocks += GredAccessesPerThreadInBlock)
{
// load block data from global memory into LDS; double buffering is not used yet (to be improved)
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf);
__syncthreads();
// construct the indices for the current toReduce blocks
blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset);
// unary operation before reducing, needed by AMAX; for MIN/MAX, nothing is actually
// done here
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf);
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
? GredAccessesPerThreadInBlock
: toReduceBlocks - reducedBlocks;
blockwise_reduce::Reduce2(in_block_val_buf,
in_block_idx_buf,
BlocksInOneOp,
accuValue_buf(I0),
accuIndex_buf(I0));
indexOffset += BlockBufferSize;
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
}
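// Index bookkeeping sketch for the loop above (assumed BlockBufferSize = 512;
// hypothetical numbers for illustration): on the third pass indexOffset has advanced to
// 2 * 512 = 1024, so a winner found at LDS slot 37 is recorded with its global column
// index 1024 + 37 = 1061 rather than its position inside the staging buffer.
static_assert(2 * 512 + 37 == 1061,
              "illustrative global index = indexOffset + local slot");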
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
// The first thread in the block stores the reduced result to the global location
// representing the block
if(thread_local_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
false>(dst1dDesc,
make_multi_index(block_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(dst1dDesc,
dst_global_val_buf,
ReducedDataDesc,
make_tuple(I0),
priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
false>(dst1dDesc,
make_multi_index(block_global_1d_id));
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<int,
int,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
false>(dst1dDesc,
make_multi_index(block_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
};
template <>
__device__ static void Run<3>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ ws_values_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)origReduceLen;
// LDS
__shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize];
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
src2dDesc.GetElementSpaceSize(),
type_convert<srcDataType>{}(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_indices_global, src2dDesc.GetElementSpaceSize());
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
indices_global, dst1dDesc.GetElementSpaceSize());
auto in_block_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
auto in_block_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(block_indices_buffer, BlockBufferSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
accuValue_buf(I0) = zeroVal;
accuIndex_buf(I0) = 0;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
constexpr auto in_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
using ThreadClusterLengths = Sequence<1, BlockSize>;
auto blockwise_src_val_load =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, BlockBufferSize>,
ThreadSliceLengths,
ThreadClusterLengths,
Sequence<0, 1>,
srcDataType,
compType,
src2dDescType,
decltype(in_block_desc),
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
1,
1,
1,
1,
false,
true>(src2dDesc,
make_multi_index(block_global_1d_id, 0),
in_block_desc,
make_multi_index(0, 0));
auto blockwise_src_idx_load =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, BlockBufferSize>,
ThreadSliceLengths,
ThreadClusterLengths,
Sequence<0, 1>,
int,
int,
src2dDescType,
decltype(in_block_desc),
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
1,
1,
1,
1,
false,
true>(src2dDesc,
make_multi_index(block_global_1d_id, 0),
in_block_desc,
make_multi_index(0, 0));
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
reducedBlocks += GredAccessesPerThreadInBlock)
{
// load block data from global memory into LDS; double buffering is not used yet (to be improved)
blockwise_src_val_load.RunRead(src2dDesc, src_global_val_buf);
blockwise_src_idx_load.RunRead(src2dDesc, src_global_idx_buf);
blockwise_src_val_load.RunWrite(in_block_desc, in_block_val_buf);
blockwise_src_idx_load.RunWrite(in_block_desc, in_block_idx_buf);
__syncthreads();
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
? GredAccessesPerThreadInBlock
: toReduceBlocks - reducedBlocks;
blockwise_reduce::Reduce2(in_block_val_buf,
in_block_idx_buf,
BlocksInOneOp,
accuValue_buf(I0),
accuIndex_buf(I0));
blockwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
blockwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
}
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
// The first thread in the block stores the reduced result to the global location
// representing the block
if(thread_local_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
true>(dst1dDesc,
make_multi_index(block_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(dst1dDesc,
dst_global_val_buf,
ReducedDataDesc,
make_tuple(I0),
priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(block_global_1d_id));
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<int,
int,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(block_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
};
};
} // namespace ck
#endif
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <index_t BlockSize,
typename srcDataType,
typename dstDataType,
typename compType,
typename src2dDescType,
typename dst1dDescType,
ReduceTensorOp_t op,
NanPropagation_t nanPropaOpt,
ReduceTensorIndices_t reduceIndicesOpt,
bool isFirstCall,
bool isLastCall,
index_t GredThreadBufferLength>
struct GridwiseReduction_xy_to_x_direct_threadwise
{
using opReduce = typename reduce_binary_operator<compType, op>::opType;
using preUnaryOpType =
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
using posUnaryOpType =
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;
static constexpr auto I0 = Number<0>{};
template <int RunId>
__device__ static void Run(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global);
template <>
__device__ static void Run<1>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)ws_indices_global;
(void)indices_global;
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true>
in_thread_buf;
using threadwise_reduce = ThreadReduce<decltype(in_thread_buf), opReduce, nanPropaOpt>;
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
accuValue_buf(I0) = zeroVal;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
const posUnaryOpType posUnaryOp(divider);
using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
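// Flat thread id sketch (assumed BlockSize = 256, block id 3, lane 17; hypothetical
// numbers for illustration): each thread reduces one output row on its own, and the row
// it owns is 3 * 256 + 17 = 785.
static_assert(3 * 256 + 17 == 785, "illustrative one-row-per-thread index");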
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
compType,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc, make_multi_index(thread_global_1d_id, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);
for(index_t reducedLength = 0; reducedLength < toReduceLength;
reducedLength += GredThreadBufferLength)
{
threadwise_src_load.Run(
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
// do element-wise pre-reduction operation
threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
// do the reduction on the Thread Buffer
threadwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0));
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
}
accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
true>(
dst1dDesc, make_multi_index(thread_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(thread_global_1d_id));
threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
};
template <>
__device__ static void Run<2>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)ws_indices_global;
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
indices_global, dst1dDesc.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true>
in_thread_buf;
using threadwise_reduce = ThreadReduce<decltype(in_thread_buf), opReduce, nanPropaOpt>;
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
accuValue_buf(I0) = zeroVal;
accuIndex_buf(I0) = 0;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
compType,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc, make_multi_index(thread_global_1d_id, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);
index_t indexStart = 0;
for(index_t reducedLength = 0; reducedLength < toReduceLength;
reducedLength += GredThreadBufferLength)
{
threadwise_src_load.Run(
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
// unary operation before reducing, needed by AMAX; for MIN/MAX, nothing is actually
// done here
threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
// do the reduction on the Thread Buffer
threadwise_reduce::Reduce2(
in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexStart);
indexStart += GredThreadBufferLength;
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
}
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
false>(
dst1dDesc, make_multi_index(thread_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
false>(dst1dDesc,
make_multi_index(thread_global_1d_id));
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<int,
int,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
false>(dst1dDesc,
make_multi_index(thread_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
};
template <>
__device__ static void Run<3>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ ws_values_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)origReduceLen;
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
src2dDesc.GetElementSpaceSize(),
type_convert<srcDataType>{}(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_indices_global, src2dDesc.GetElementSpaceSize());
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
indices_global, dst1dDesc.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, GredThreadBufferLength, true> in_thread_idx_buf;
using threadwise_reduce = ThreadReduceWithIndicesInput<decltype(in_thread_val_buf),
decltype(in_thread_idx_buf),
opReduce,
nanPropaOpt>;
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
accuValue_buf(I0) = zeroVal;
accuIndex_buf(I0) = 0;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
compType,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc, make_multi_index(thread_global_1d_id, 0));
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<int,
int,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc, make_multi_index(thread_global_1d_id, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);
for(index_t reducedLength = 0; reducedLength < toReduceLength;
reducedLength += GredThreadBufferLength)
{
threadwise_src_val_load.Run(src2dDesc,
src_global_val_buf,
ThreadBufferDesc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(src2dDesc,
src_global_idx_buf,
ThreadBufferDesc,
make_tuple(I0, I0),
in_thread_idx_buf);
// do the reduction on the Thread Buffer
threadwise_reduce::Reduce(
in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0));
threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
}
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
false>(
dst1dDesc, make_multi_index(thread_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
false>(dst1dDesc,
make_multi_index(thread_global_1d_id));
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<int,
int,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
false>(dst1dDesc,
make_multi_index(thread_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
};
};
} // namespace ck
#endif
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_warpwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <index_t BlockSize,
typename srcDataType,
typename dstDataType,
typename compType,
typename src2dDescType,
typename dst1dDescType,
ReduceTensorOp_t op,
NanPropagation_t nanPropaOpt,
ReduceTensorIndices_t reduceIndicesOpt,
bool isFirstCall,
bool isLastCall,
index_t GredAccessesPerThreadInWarp>
struct GridwiseReduction_xy_to_x_direct_warpwise
{
using opReduce = typename reduce_binary_operator<compType, op>::opType;
using preUnaryOpType =
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
using posUnaryOpType =
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;
static constexpr auto I0 = Number<0>{};
template <int RunId>
__device__ static void Run(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global);
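// RunId selects the kernel variant (see the specializations below): 1 - reduction without
// indices output; 2 - reduction that also generates the indices from scratch; 3 - secondary
// reduction that consumes the values and indices workspace produced by a previous kernel.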
template <>
__device__ static void Run<1>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)ws_indices_global;
(void)indices_global;
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
in_thread_buf;
using warpwise_reduce =
WarpReduce<decltype(in_thread_buf), BlockSize, opReduce, nanPropaOpt>;
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
accuValue_buf(I0) = zeroVal;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
const posUnaryOpType posUnaryOp(divider);
using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>;
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
index_t warp_global_1d_id = thread_global_1d_id / warpSize;
index_t thread_inwarp_id = thread_global_1d_id % warpSize;
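// Each warp handles one row (one output value) of the 2d source; within the warp, every lane
// starts at its own chunk of GredAccessesPerThreadInWarp elements of the reduce dimension.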
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
compType,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc,
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
constexpr auto in_thread_copy_step =
make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);
for(index_t reducedLength = 0; reducedLength < toReduceLength;
reducedLength += warpSize * GredAccessesPerThreadInWarp)
{
threadwise_src_load.Run(
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
// do element-wise pre-reduction operation
warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
// do the warp-wise reduction on data of all thread buffers
warpwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0));
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
}
accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
// The first thread in the warp stores the reduced result to the global location
// representing the warp
if(thread_inwarp_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
}
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
}
};
template <>
__device__ static void Run<2>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)ws_indices_global;
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
indices_global, dst1dDesc.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
in_thread_buf;
using warpwise_reduce =
WarpReduce<decltype(in_thread_buf), BlockSize, opReduce, nanPropaOpt>;
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
accuValue_buf(I0) = zeroVal;
accuIndex_buf(I0) = 0;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>;
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
index_t warp_global_1d_id = thread_global_1d_id / warpSize;
index_t thread_inwarp_id = thread_global_1d_id % warpSize;
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
compType,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc,
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
constexpr auto in_thread_copy_step =
make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);
index_t indexOffset = 0;
for(index_t reducedLength = 0; reducedLength < toReduceLength;
reducedLength += warpSize * GredAccessesPerThreadInWarp)
{
threadwise_src_load.Run(
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
// element-wise unary operation before reducing, needed by AMAX; for MIN/MAX nothing is
// actually done here
warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
// do the warp-wise reduction on data of all thread buffers
warpwise_reduce::Reduce2(
in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexOffset);
indexOffset += warpSize * GredAccessesPerThreadInWarp;
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
}
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
// The first thread in the warp stores the reduced result to the global location
// representing the warp
if(thread_inwarp_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(dst1dDesc,
dst_global_val_buf,
ReducedDataDesc,
make_tuple(I0),
priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<int,
int,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
};
template <>
__device__ static void Run<3>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
srcDataType alpha,
const srcDataType* const __restrict__ ws_values_global,
dstDataType beta,
dstDataType* const __restrict__ p_dst_global,
const int* const __restrict__ ws_indices_global,
int* const __restrict__ indices_global)
{
(void)origReduceLen;
const auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
src2dDesc.GetElementSpaceSize(),
type_convert<srcDataType>{}(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_indices_global, src2dDesc.GetElementSpaceSize());
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_dst_global, dst1dDesc.GetElementSpaceSize());
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
indices_global, dst1dDesc.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, GredAccessesPerThreadInWarp, true>
in_thread_idx_buf;
using warpwise_reduce = WarpReduceWithIndicesInput<decltype(in_thread_val_buf),
decltype(in_thread_idx_buf),
BlockSize,
opReduce,
nanPropaOpt>;
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
accuValue_buf(I0) = zeroVal;
accuIndex_buf(I0) = 0;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>;
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
index_t warp_global_1d_id = thread_global_1d_id / warpSize;
index_t thread_inwarp_id = thread_global_1d_id % warpSize;
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
compType,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc,
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<int,
int,
src2dDescType,
decltype(ThreadBufferDesc),
ThreadBufferLengths,
Sequence<0, 1>,
1,
1,
1,
false>(
src2dDesc,
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
constexpr auto in_thread_copy_step =
make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);
for(index_t reducedLength = 0; reducedLength < toReduceLength;
reducedLength += warpSize * GredAccessesPerThreadInWarp)
{
threadwise_src_val_load.Run(src2dDesc,
src_global_val_buf,
ThreadBufferDesc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(src2dDesc,
src_global_idx_buf,
ThreadBufferDesc,
make_tuple(I0, I0),
in_thread_idx_buf);
// do the warp-wise reduction on data of all thread buffers
warpwise_reduce::Reduce(
in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0));
threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
}
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
// The first thread in the warp stores the reduced result to the global location
// representing the warp
if(thread_inwarp_id == 0)
{
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<dstDataType,
dstDataType,
dst1dDescType,
decltype(ReducedDataDesc),
Sequence<1>,
Sequence<0>,
0,
1,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
threadwise_dst_load.Run(dst1dDesc,
dst_global_val_buf,
ReducedDataDesc,
make_tuple(I0),
priorDstValue_buf);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<int,
int,
decltype(ReducedDataDesc),
dst1dDescType,
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(dst1dDesc,
make_multi_index(warp_global_1d_id));
threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
};
};
} // namespace ck
#endif
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_blockwise.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
namespace ck {
template <index_t BlockSize,
typename srcDataType,
typename dstDataType, // only used to type the (unused) beta input
typename compType,
typename src2dDescType,
typename dst1dDescType,
ReduceTensorOp_t op,
NanPropagation_t nanPropaOpt,
ReduceTensorIndices_t reduceIndicesOpt,
index_t GredAccessesPerThreadInBlock>
struct GridwiseReduction_xy_to_x_multiblock
{
using opReduce = typename reduce_binary_operator<compType, op>::opType;
using preUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::preUnaryOp;
using posUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::posUnaryOp;
static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed(
make_tuple(Number<GredAccessesPerThreadInBlock>{}, Number<BlockSize>{}));
using blockwise_reduce =
BlockwiseReduction_2d_block_buffer<decltype(buffer2dDesc), true, opReduce, nanPropaOpt>;
static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize();
static constexpr auto I0 = Number<0>{};
template <int RunId>
__device__ static void Run(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
int BlkGroupSize,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
srcDataType* const __restrict__ ws_values_global,
int* const __restrict__ ws_indices_global);
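// RunId selects the kernel variant: 1 - partial reduction of values only; 2 - partial
// reduction of values together with their flattened indices. Each of the BlkGroupSize blocks
// assigned to an output value writes one partial result into the workspace (of length
// dst1dDesc.GetLength(I0) * BlkGroupSize); a subsequent kernel finishes the reduction, which
// is why alpha and beta are not applied here.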
template <>
__device__ static void Run<1>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
int BlkGroupSize,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
srcDataType* const __restrict__ ws_values_global,
int* const __restrict__ ws_indices_global)
{
(void)ws_indices_global;
(void)alpha; // unused
(void)beta; // unused
const auto zeroVal = opReduce::GetReductionZeroVal();
// LDS
__shared__ compType p_in_block_buffer[BlockBufferSize];
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
auto in_block_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
accuValue_buf(I0) = zeroVal;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / BlkGroupSize;
const index_t block_local_id = block_global_id % BlkGroupSize;
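// blkgroup_id selects the output value this block works on; block_local_id selects which
// slice of the reduce dimension it covers. reduceSizePerBlock below is the per-block share
// of toReduceLength rounded up to a multiple of BlockBufferSize.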
const index_t reduceSizePerBlock =
(((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
BlockBufferSize) *
BlockBufferSize;
constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
using ThreadClusterLengths = Sequence<1, BlockSize>;
auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, BlockBufferSize>,
ThreadSliceLengths,
ThreadClusterLengths,
Sequence<0, 1>,
srcDataType,
compType,
src2dDescType,
decltype(in_block_desc),
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
1,
1,
1,
1,
false,
true>(
src2dDesc,
make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
in_block_desc,
make_multi_index(0, 0));
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
reducedBlocks += GredAccessesPerThreadInBlock)
{
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
blockwise_src_load.RunWrite(in_block_desc, in_block_buf);
__syncthreads();
// do element-wise pre-reduction operation
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf);
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
? GredAccessesPerThreadInBlock
: toReduceBlocks - reducedBlocks;
blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0));
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
}
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
const auto workspace_desc =
make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));
// The first thread in the block stores the reduced result to the global location
// representing the block
if(thread_local_id == 0)
{
auto threadwise_workspace_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
srcDataType,
decltype(ReducedDataDesc),
decltype(workspace_desc),
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(workspace_desc,
make_multi_index(block_global_id));
threadwise_workspace_store.Run(ReducedDataDesc,
make_tuple(I0),
accuValue_buf,
workspace_desc,
workspace_global_buf);
}
};
template <>
__device__ static void Run<2>(const src2dDescType& src2dDesc,
const dst1dDescType& dst1dDesc,
int origReduceLen,
int BlkGroupSize,
srcDataType alpha,
const srcDataType* const __restrict__ p_src_global,
dstDataType beta,
srcDataType* const __restrict__ ws_values_global,
int* const __restrict__ ws_indices_global)
{
(void)alpha; // unused
(void)beta; // unused
const auto zeroVal = opReduce::GetReductionZeroVal();
// LDS
__shared__ compType p_in_block_values_buffer[BlockBufferSize];
__shared__ int p_in_block_indices_buffer[BlockBufferSize];
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
ws_indices_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
auto in_block_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_values_buffer, BlockBufferSize);
auto in_block_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_in_block_indices_buffer, BlockBufferSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
accuValue_buf(I0) = zeroVal;
accuIndex_buf(I0) = 0;
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
const int divider = origReduceLen;
const preUnaryOpType preUnaryOp(divider);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / BlkGroupSize;
const index_t block_local_id = block_global_id % BlkGroupSize;
const index_t reduceSizePerBlock =
(((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
BlockBufferSize) *
BlockBufferSize;
constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
using ThreadClusterLengths = Sequence<1, BlockSize>;
auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<1, BlockBufferSize>,
ThreadSliceLengths,
ThreadClusterLengths,
Sequence<0, 1>,
srcDataType,
compType,
src2dDescType,
decltype(in_block_desc),
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
1,
1,
1,
1,
false,
true>(
src2dDesc,
make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
in_block_desc,
make_multi_index(0, 0));
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;
int indexOffset = block_local_id * reduceSizePerBlock;
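// Indices written into LDS are global positions along the reduce dimension, so each block
// starts from the offset of its own slice and advances by BlockBufferSize per iteration.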
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
reducedBlocks += GredAccessesPerThreadInBlock)
{
blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset);
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf);
__syncthreads();
// element-wise unary operation before reducing, needed by AMAX; for MIN/MAX nothing is
// actually done here
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf);
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
? GredAccessesPerThreadInBlock
: toReduceBlocks - reducedBlocks;
blockwise_reduce::Reduce2(in_block_val_buf,
in_block_idx_buf,
BlocksInOneOp,
accuValue_buf(I0),
accuIndex_buf(I0));
indexOffset += BlockBufferSize;
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
}
constexpr auto ReducedDataDesc =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
const auto workspace_desc =
make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));
// The first thread in the block stores the reduced result to the global location
// representing the block
if(thread_local_id == 0)
{
auto threadwise_workspace_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
srcDataType,
decltype(ReducedDataDesc),
decltype(workspace_desc),
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(workspace_desc,
make_multi_index(block_global_id));
auto threadwise_workspace_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<int,
int,
decltype(ReducedDataDesc),
decltype(workspace_desc),
Sequence<1>,
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
1,
true>(workspace_desc,
make_multi_index(block_global_id));
threadwise_workspace_val_store.Run(ReducedDataDesc,
make_tuple(I0),
accuValue_buf,
workspace_desc,
workspace_global_val_buf);
threadwise_workspace_idx_store.Run(ReducedDataDesc,
make_tuple(I0),
accuIndex_buf,
workspace_desc,
workspace_global_idx_buf);
}
};
};
} // namespace ck
#endif
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace ck {
template <typename buffer2dDescType,
bool blockIsOneRow,
typename opReduce,
NanPropagation_t nanPropaOpt>
struct BlockwiseReduction_2d_block_buffer
{
using compType = typename opReduce::dataType;
static constexpr auto buffer2dDesc = buffer2dDescType{};
static constexpr index_t BlockSize =
blockIsOneRow ? buffer2dDesc.GetLength(Number<1>{}) : buffer2dDesc.GetLength(Number<0>{});
static constexpr index_t NumBlocks =
blockIsOneRow ? buffer2dDesc.GetLength(Number<0>{}) : buffer2dDesc.GetLength(Number<1>{});
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
// This interface does not accumulate on indices
template <typename BufferType>
__device__ static void
Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData)
{
const index_t thread_local_id = get_thread_local_1d_id();
compType lAccuData = opReduce::GetReductionZeroVal();
index_t offset;
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
{
offset = blockIsOneRow
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id))
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
compType opData = type_convert<compType>{}(block_buffer[offset]);
binop::calculate(lAccuData, opData);
}
offset = blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
block_buffer(offset) = lAccuData;
__syncthreads();
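// Tree-reduce the per-thread partial results: each round halves the number of active
// threads (BlockSize/2, BlockSize/4, ..., 1), with thread i folding in the value held by
// thread i + indOffset.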
for(index_t indOffset = BlockSize / 2; indOffset > 0; indOffset /= 2)
{
if(thread_local_id < indOffset)
{
index_t offset1 =
blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
index_t offset2 =
blockIsOneRow
? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset))
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
compType opData1 = type_convert<compType>{}(block_buffer[offset1]);
compType opData2 = type_convert<compType>{}(block_buffer[offset2]);
binop::calculate(opData1, opData2);
block_buffer(offset1) = type_convert<compType>{}(opData1);
}
__syncthreads();
}
if(thread_local_id == 0)
{
compType tmpVal = type_convert<compType>{}(block_buffer[0]);
binop::calculate(accuData, tmpVal);
}
};
// This interface accumulates on both data values and indices
template <typename BufferType, typename IdxBufferType>
__device__ static void Reduce2(BufferType& block_buffer,
IdxBufferType& block_indices_buffer,
index_t toReduceBlocks,
compType& accuData,
int& accuIndex)
{
const index_t thread_local_id = get_thread_local_1d_id();
compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0;
if constexpr(blockIsOneRow)
{
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
{
for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
{
if(thread_local_id % (indOffset * 2) == 0)
{
index_t offset1 =
buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id));
index_t offset2 = buffer2dDesc.CalculateOffset(
make_tuple(otherDimInd, thread_local_id + indOffset));
compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
int currIndex1 = block_indices_buffer[offset1];
int currIndex2 = block_indices_buffer[offset2];
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
block_buffer(offset1) = type_convert<compType>{}(currVal1);
block_indices_buffer(offset1) = currIndex1;
}
__syncthreads();
}
}
if(thread_local_id == 0)
{
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
{
index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0));
compType tmpVal = type_convert<compType>{}(block_buffer[offset]);
int tmpIndex = block_indices_buffer[offset];
binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
}
binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
}
}
else
{
index_t offset;
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
{
offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
compType currVal = type_convert<compType>{}(block_buffer[offset]);
int currIndex = block_indices_buffer[offset];
binop::calculate(lAccuData, currVal, lAccuIndex, currIndex);
}
offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
block_buffer(offset) = lAccuData;
block_indices_buffer(offset) = lAccuIndex;
__syncthreads();
for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
{
if(thread_local_id % (indOffset * 2) == 0)
{
index_t offset1 = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
index_t offset2 =
buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
int currIndex1 = block_indices_buffer[offset1];
int currIndex2 = block_indices_buffer[offset2];
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
block_buffer(offset1) = type_convert<compType>{}(currVal1);
block_indices_buffer(offset1) = currIndex1;
}
__syncthreads();
}
if(thread_local_id == 0)
{
compType tmpVal = type_convert<compType>{}(block_buffer[0]);
int tmpIndex = block_indices_buffer[0];
binop::calculate(accuData, tmpVal, accuIndex, tmpIndex);
}
}
};
template <typename BufferType>
__device__ static void set_buffer_value(BufferType& block_buffer, compType value)
{
index_t thread_id = get_thread_local_1d_id();
for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
{
index_t offset = blockIsOneRow
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
: buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));
block_buffer(offset) = value;
__syncthreads();
}
};
// Initialize the block-wise indices buffer: the index of each element in the block-wise
// data buffer is its offset within the buffer plus the global starting index
template <typename IdxBufferType>
__device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart)
{
index_t thread_id = get_thread_local_1d_id();
for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
{
index_t offset = blockIsOneRow
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
: buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));
block_indices_buffer(offset) = offset + indexStart;
__syncthreads();
}
};
// Execute unary operation on the block buffer elements
template <typename unary_op_type, typename BufferType>
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& block_buffer)
{
index_t thread_id = get_thread_local_1d_id();
for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
{
index_t offset = blockIsOneRow
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
: buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));
block_buffer(offset) = unary_op(block_buffer[offset]);
__syncthreads();
}
};
};
}; // end of namespace ck
#endif
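// ---------------------------------------------------------------------------------------
// A minimal host-side sketch (not part of the library) illustrating the tree-reduction
// pattern used by BlockwiseReduction_2d_block_buffer::Reduce above: once every thread has
// written its serial partial result into the shared buffer, the number of active lanes is
// halved each round until lane 0 holds the block-wide result. The helper name tree_reduce
// and the power-of-two size are illustrative assumptions only.
#include <cstddef>
#include <functional>
#include <vector>
template <typename T, typename BinOp>
T tree_reduce(std::vector<T> buf, BinOp op)
{
    // buf.size() plays the role of BlockSize and is assumed to be a power of two,
    // matching the indOffset = BlockSize/2, BlockSize/4, ..., 1 schedule in Reduce().
    for(std::size_t stride = buf.size() / 2; stride > 0; stride /= 2)
        for(std::size_t i = 0; i < stride; ++i)    // "if(thread_local_id < indOffset)"
            buf[i] = op(buf[i], buf[i + stride]);  // lane i folds in lane i + stride
    return buf[0];
}
// Example: tree_reduce<float>({1, 2, 3, 4, 5, 6, 7, 8}, std::plus<float>{}) returns 36.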
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace ck {
template <typename BufferType, typename opReduce, NanPropagation_t nanPropaOpt>
struct ThreadReduce
{
using compType = typename opReduce::dataType;
static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs to use a StaticBuffer!");
static_assert(
std::is_same<typename BufferType::type, compType>::value,
"Data type of the StaticBuffer for Thread-wise reduction should be the same as compType!");
static constexpr index_t ThreadBufferLen = BufferType::Size();
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
// This interface does not accumulate on indices
__device__ static void Reduce(const BufferType& thread_buffer, compType& accuData)
{
static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { binop::calculate(accuData, thread_buffer[I]); });
};
// This interface accumulates on both data values and indices and is called by the
// Direct_ThreadWise reduction method for the first-time reduction
__device__ static void
Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
{
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
int currIndex = I + indexStart;
binop::calculate(accuData, thread_buffer[I], accuIndex, currIndex);
});
};
// Set the elements in the per-thread buffer to a specific value
// cppcheck-suppress constParameter
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
{
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
};
// Execute unary operation on the per-thread buffer elements
template <typename unary_op_type>
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
{
static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
};
};
template <typename BufferType,
typename IdxBufferType,
typename opReduce,
NanPropagation_t nanPropaOpt>
struct ThreadReduceWithIndicesInput
{
using compType = typename opReduce::dataType;
static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs to use a StaticBuffer!");
static_assert(IdxBufferType::IsStaticBuffer(),
"Thread-wise reduction needs to use a StaticBuffer for indices!");
static_assert(
std::is_same<typename BufferType::type, compType>::value,
"Data type of the StaticBuffer for Thread-wise reduction should be the same as compType!");
static_assert(std::is_same<typename IdxBufferType::type, index_t>::value,
"Indices type of the StaticBuffer for Thread-wise reduction should be index_t!");
static_assert(BufferType::Size() == IdxBufferType::Size(),
"StaticBuffers for data and indices should have the same sizes!");
static constexpr index_t ThreadBufferLen = BufferType::Size();
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
// This interface accumulates on both data values and indices and is called by the
// Direct_ThreadWise reduction method for the second-time reduction
__device__ static void Reduce(const BufferType& thread_buffer,
const IdxBufferType& thread_indices_buffer,
compType& accuData,
int& accuIndex)
{
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
binop::calculate(accuData, thread_buffer[I], accuIndex, thread_indices_buffer[I]);
});
};
// Set the elements in the per-thread buffer to a specific value
// cppcheck-suppress constParameter
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
{
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
};
// Execute unary operation on the per-thread buffer elements
template <typename unary_op_type>
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
{
static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
};
};
}; // end of namespace ck
#endif
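// ---------------------------------------------------------------------------------------
// A minimal host-side sketch (not part of the library) of the value-plus-index accumulation
// that ThreadReduce::Reduce2 performs for MIN/MAX/AMAX-style reductions: each element is
// compared against the running accumulator and the flattened source index travels with the
// winning value. MaxWithIndex and thread_reduce2 are hypothetical names used only here.
#include <cstddef>
struct MaxWithIndex
{
    // Mirrors the 4-argument binop::calculate(accuVal, newVal, accuIdx, newIdx) shape.
    static void calculate(float& accuVal, float newVal, int& accuIdx, int newIdx)
    {
        if(newVal > accuVal)
        {
            accuVal = newVal;
            accuIdx = newIdx;
        }
    }
};
inline void thread_reduce2(
    const float* thread_buffer, std::size_t len, float& accuData, int& accuIndex, int indexStart)
{
    // Same index construction as Reduce2: currIndex = I + indexStart
    for(std::size_t i = 0; i < len; ++i)
        MaxWithIndex::calculate(
            accuData, thread_buffer[i], accuIndex, static_cast<int>(i) + indexStart);
}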
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
#define CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace ck {
template <typename BufferType, index_t BlockSize, typename opReduce, NanPropagation_t nanPropaOpt>
struct WarpReduce
{
using compType = typename opReduce::dataType;
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
static_assert(BufferType::IsStaticBuffer(),
"Per-thread buffer for WarpWise reduction should be a StaticBuffer!");
static_assert(std::is_same<typename BufferType::type, compType>::value,
"Data type of the per-thread StaticBuffer for WarpWise reduction should be the same as "
"compType!");
static constexpr index_t ThreadBufferLen = BufferType::Size();
static constexpr bool have_builtin_shuffle =
std::is_same<compType, float>::value || std::is_same<compType, double>::value;
// This interface does not accumulate on indices
__device__ static void Reduce(const BufferType& thread_buffer, compType& accuData)
{
if constexpr(have_builtin_shuffle)
ReduceImpl1(thread_buffer, accuData);
else
ReduceImpl2(thread_buffer, accuData);
};
// This interface implementation uses HIP built-in device shuffling functions
__device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData)
{
compType lAccuData = opReduce::GetReductionZeroVal();
static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
// synchronize among all threads in this warp
__all(1);
for(index_t stride = warpSize / 2; stride > 0; stride /= 2)
{
compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
binop::calculate(lAccuData, tmpVal);
__all(1);
}
binop::calculate(accuData, lAccuData);
};
// This interface implementation does not use the HIP built-in device shuffle functions,
// since HIP does not provide built-in shuffle functions for fp16
__device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData)
{
compType lAccuData = opReduce::GetReductionZeroVal();
static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
__syncthreads();
index_t thread_id = get_thread_local_1d_id();
index_t warpId = thread_id / warpSize;
index_t thread_inwarp_id = thread_id % warpSize;
__shared__ compType shuffle_buffer[BlockSize];
compType* myBuffer = &shuffle_buffer[warpId * warpSize];
myBuffer[thread_inwarp_id] = lAccuData;
__syncthreads();
for(index_t stride = warpSize / 2; stride > 0; stride /= 2)
{
if(thread_inwarp_id < stride)
{
compType currVal1 = myBuffer[thread_inwarp_id];
compType currVal2 = myBuffer[thread_inwarp_id + stride];
binop::calculate(currVal1, currVal2);
myBuffer[thread_inwarp_id] = currVal1;
}
__syncthreads();
}
if(thread_inwarp_id == 0)
binop::calculate(accuData, myBuffer[0]);
};
// This interface accumulates on both data values and indices and is called by the
// Direct_WarpWise reduction method for the first-time reduction
__device__ static void
Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
{
if constexpr(have_builtin_shuffle)
Reduce2Impl1(thread_buffer, accuData, accuIndex, indexStart);
else
Reduce2Impl2(thread_buffer, accuData, accuIndex, indexStart);
};
// This interface implementation uses HIP built-in device shuffling functions
__device__ static void Reduce2Impl1(const BufferType& thread_buffer,
compType& accuData,
int& accuIndex,
int indexStart)
{
compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0;
index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize;
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart;
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex);
});
// synchronize among all threads in this warp
__all(1);
for(index_t stride = 1; stride < warpSize; stride *= 2)
{
compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize);
binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
__all(1);
}
if(thread_inwarp_id == 0)
binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
};
// This interface implementation does not use the HIP built-in device shuffle functions,
// since HIP does not provide built-in shuffle functions for fp16
__device__ static void Reduce2Impl2(const BufferType& thread_buffer,
compType& accuData,
int& accuIndex,
int indexStart)
{
compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0;
index_t thread_id = get_thread_local_1d_id();
index_t warpId = thread_id / warpSize;
index_t thread_inwarp_id = thread_id % warpSize;
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart;
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex);
});
__shared__ compType shuffle_data_buffer[BlockSize];
__shared__ int shuffle_indices_buffer[BlockSize];
compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize];
myDataBuffer[thread_inwarp_id] = lAccuData;
myIndicesBuffer[thread_inwarp_id] = lAccuIndex;
__syncthreads();
for(index_t stride = 1; stride < warpSize; stride *= 2)
{
compType currVal1 = myDataBuffer[thread_inwarp_id];
compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
int currIndex1 = myIndicesBuffer[thread_inwarp_id];
int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride];
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
myDataBuffer[thread_inwarp_id] = currVal1;
myIndicesBuffer[thread_inwarp_id] = currIndex1;
__syncthreads();
}
if(thread_inwarp_id == 0)
binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
};
// cppcheck-suppress constParameter
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
{
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
__all(1);
};
// Execute unary operation on the per-thread buffer elements
template <typename unary_op_type>
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
{
static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
__all(1);
};
};
template <typename BufferType,
typename IdxBufferType,
index_t BlockSize,
typename opReduce,
NanPropagation_t nanPropaOpt>
struct WarpReduceWithIndicesInput
{
using compType = typename opReduce::dataType;
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
static_assert(BufferType::IsStaticBuffer(),
"Per-thread buffer for WarpWise reduction should be a StaticBuffer!");
static_assert(IdxBufferType::IsStaticBuffer(),
"Per-thread indices buffer for WarpWise reduction should be a StaticBuffer!");
static_assert(std::is_same<typename BufferType::type, compType>::value,
"Data type of the per-thread StaticBuffer for WarpWise reduction should be the same as "
"compType!");
static_assert(
std::is_same<typename IdxBufferType::type, index_t>::value,
"Indices type of the per-thread StaticBuffer for WarpWise reduction should be index_t!");
static_assert(BufferType::Size() == IdxBufferType::Size(),
"StaticBuffers for data and indices should have the same sizes!");
static constexpr index_t ThreadBufferLen = BufferType::Size();
static constexpr bool have_builtin_shuffle =
std::is_same<compType, float>::value || std::is_same<compType, double>::value;
// This interface accumulates on both data values and indices and is called by the
// Direct_WarpWise reduction method for the second-time reduction
__device__ static void Reduce(const BufferType& thread_buffer,
const IdxBufferType& thread_indices_buffer,
compType& accuData,
int& accuIndex)
{
if constexpr(have_builtin_shuffle)
ReduceImpl1(thread_buffer, thread_indices_buffer, accuData, accuIndex);
else
ReduceImpl2(thread_buffer, thread_indices_buffer, accuData, accuIndex);
};
// This interface implementation uses HIP built-in device shuffling functions
__device__ static void ReduceImpl1(const BufferType& thread_buffer,
const IdxBufferType& thread_indices_buffer,
compType& accuData,
int& accuIndex)
{
compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0;
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
});
// synchronize among all threads in this warp
__all(1);
for(index_t stride = 1; stride < warpSize; stride *= 2)
{
compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize);
binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
__all(1);
}
binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
};
// This interface implementation does not use the HIP built-in device shuffle functions,
// since HIP does not provide built-in shuffle functions for fp16
__device__ static void ReduceImpl2(const BufferType& thread_buffer,
const IdxBufferType& thread_indices_buffer,
compType& accuData,
int& accuIndex)
{
compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0;
index_t thread_id = get_thread_local_1d_id();
index_t warpId = thread_id / warpSize;
index_t thread_inwarp_id = thread_id % warpSize;
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
});
__shared__ compType shuffle_data_buffer[BlockSize];
__shared__ int shuffle_indices_buffer[BlockSize];
compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize];
myDataBuffer[thread_inwarp_id] = lAccuData;
myIndicesBuffer[thread_inwarp_id] = lAccuIndex;
__syncthreads();
for(index_t stride = 1; stride < warpSize; stride *= 2)
{
compType currVal1 = myDataBuffer[thread_inwarp_id];
compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
int currIndex1 = myIndicesBuffer[thread_inwarp_id];
int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride];
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
myDataBuffer[thread_inwarp_id] = currVal1;
myIndicesBuffer[thread_inwarp_id] = currIndex1;
__syncthreads();
}
if(thread_inwarp_id == 0)
binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
};
// cppcheck-suppress constParameter
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
{
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
__all(1);
};
// Execute unary operation on the per-thread buffer elements
template <typename unary_op_type>
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
{
static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
__all(1);
};
};
}; // end of namespace ck
#endif
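// ---------------------------------------------------------------------------------------
// A minimal HIP sketch (standalone illustration, not library code) of the shuffle-based path
// used by WarpReduce::ReduceImpl1: every lane folds in the value held by the lane `stride`
// positions above it, halving the stride until lane 0 holds the warp-wide result. The helper
// name warp_sum is an assumption; only the __shfl_down/warpSize built-ins are relied upon.
#include <hip/hip_runtime.h>
__device__ float warp_sum(float v)
{
    for(int stride = warpSize / 2; stride > 0; stride /= 2)
        v += __shfl_down(v, stride, warpSize); // lane i accumulates lane i + stride
    return v; // the full sum ends up in lane 0 of the warp
}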
@@ -397,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v2
                       "wrong! SrcDesc need to known at compile-time");
     }
-    __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
+    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
     {
         src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
     }
@@ -713,9 +713,6 @@ struct ThreadwiseTensorSliceTransfer_v3
         : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
           dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
     {
-        // TODO: fix this
-        static_assert(is_same<SrcData, DstData>::value,
-                      "wrong! current implementation assume SrcData and DstData are same type");
     }
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -985,7 +982,8 @@ struct ThreadwiseTensorSliceTransfer_v3
                 constexpr index_t buffer_offset =
                     buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
-                dst_tmp_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
+                dst_tmp_vector.template AsType<DstData>()(i) =
+                    type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
             });
             using dst_vector_t = typename decltype(dst_tmp_vector)::type;
@@ -44,15 +44,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -71,15 +66,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -98,15 +88,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -125,15 +110,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -153,15 +133,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -180,15 +155,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -207,15 +177,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -234,15 +199,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -261,15 +221,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -288,15 +243,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t COffset,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
@@ -732,7 +682,7 @@ struct XdlopsGemm
         return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
     }

-    template <index_t c_offset, class FloatA, class FloatB, class FloatC>
+    template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
@@ -740,8 +690,7 @@ struct XdlopsGemm
                           "base base_type must be float, half, ushort!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
-            mfma_instr.template run<MPerXdlops, NPerXdlops, c_offset>(
-                p_a_wave[k], p_b_wave[k], p_c_thread);
+            mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
         });
     }
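With the `c_offset`/`COffset` template parameter gone, `Run` only unrolls over the K dimension at compile time and forwards each slice to the selected MFMA instruction; the accumulator sub-vector is picked by the caller rather than inside the intrinsic wrapper. The sketch below illustrates the compile-time-unrolled loop pattern that `static_for` provides, using a hypothetical standard-C++ helper (not the library's implementation):

```cpp
#include <cstdio>
#include <utility>

// Hypothetical compile-time-unrolled loop in the spirit of ck::static_for:
// invokes f with std::integral_constant<int, 0> ... std::integral_constant<int, N-1>.
template <typename F, int... Is>
constexpr void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, typename F>
constexpr void static_for_sketch(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

int main()
{
    constexpr int KPack     = 8;
    constexpr int k_per_blk = 4; // e.g. four f16 values consumed per mfma issue

    // Analogue of the k loop in XdlopsGemm::Run: one mfma issue per k slice,
    // with the slice index known at compile time.
    static_for_sketch<KPack / k_per_blk>([](auto k) {
        std::printf("issue mfma for k slice %d\n", int(k));
    });
}
```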
@@ -819,8 +768,9 @@ struct XdlopsGemm
     static constexpr auto mfma_instr = mfma.selected_mfma;

     static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
-    static constexpr auto KPerThread = mfma.GetKPerThread();
+    static constexpr auto K1PerXdlops = mfma.GetKPerThread();
+    static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;

     __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
     {
...
@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
 extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
     ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x1f32;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x1f32<64, 64>
 {
     template <class FloatC>
     __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
-                1,
-                0,
-                0);
-        reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
-                1,
-                1,
-                0);
+        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+        reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
     }
 };

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x1f32<32, 64>
 {
     template <class FloatC>
     __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
-                1,
-                0,
-                0);
+        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x2f32;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
+template <>
+struct intrin_mfma_f32_32x32x2f32<32, 32>
 {
     template <class FloatC>
     __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_16x16x4f32;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
+template <>
+struct intrin_mfma_f32_16x16x4f32<16, 16>
 {
     template <class FloatC>
     __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_16x16x1f32;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
+template <>
+struct intrin_mfma_f32_16x16x1f32<16, 64>
 {
     template <class FloatC>
     __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
-                2,
-                0,
-                0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_4x4x1f32;

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x1f32<4, 64>
 {
     template <class FloatC>
     __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
-                4,
-                0,
-                0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
     }
 };

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x1f32<8, 64>
 {
     template <class FloatC>
     __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
-                4,
-                0,
-                0);
-        reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
-                4,
-                1,
-                0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
+        reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x4f16;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x4f16<64, 64>
 {
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
-                1,
-                0,
-                0);
-        reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
-                1,
-                1,
-                0);
+        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+        reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
     }
 };

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x4f16<32, 64>
 {
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
-                1,
-                0,
-                0);
+        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x8f16;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
+template <>
+struct intrin_mfma_f32_32x32x8f16<32, 32>
 {
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_16x16x16f16;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
+template <>
+struct intrin_mfma_f32_16x16x16f16<16, 16>
 {
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_16x16x4f16;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
+template <>
+struct intrin_mfma_f32_16x16x4f16<16, 64>
 {
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
-                2,
-                0,
-                0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
     }
 };

-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_4x4x4f16;

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x4f16<4, 64>
 {
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
-                4,
-                0,
-                0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
     }
 };

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x4f16<8, 64>
 {
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
-                4,
-                0,
-                0);
-        reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
-                reg_a,
-                reg_b,
-                reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
-                4,
-                1,
-                0);
+        reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
+        reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
     }
 };
@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave>
 __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
                                                             const ushort2_t* reg_b,
                                                             c_vec16_1_t::VecType reg_c);
-
 template <>
 __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
                                                                     const ushort2_t* reg_b,
...
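All of the `intrin_mfma_*` wrappers above follow the same dispatch pattern: the primary template is only declared, and an explicit specialization is provided per supported (MPerWave, NPerWave) tile, so an unsupported tile size fails at compile time. A stripped-down sketch of that pattern with hypothetical names (no real intrinsics, plain scalar arithmetic standing in for the mfma calls):

```cpp
#include <cstdio>

using index_t = int;

// Primary template is declared but not defined, so instantiating an
// unsupported (MPerWave, NPerWave) combination is a compile-time error.
template <index_t MPerWave, index_t NPerWave>
struct intrin_gemm_sketch;

template <>
struct intrin_gemm_sketch<32, 32>
{
    // One accumulator block.
    static void Run(float a, float b, float* c) { c[0] += a * b; }
};

template <>
struct intrin_gemm_sketch<64, 64>
{
    // The real 64x64 wrappers issue two mfma calls, one per accumulator sub-vector.
    static void Run(float a, float b, float* c)
    {
        c[0] += a * b;
        c[1] += a * b;
    }
};

int main()
{
    float c[2] = {0.0f, 0.0f};
    intrin_gemm_sketch<64, 64>::Run(1.5f, 2.0f, c);
    std::printf("c = [%f, %f]\n", c[0], c[1]);
    // intrin_gemm_sketch<48, 64>::Run(1.0f, 1.0f, c); // would not compile: no specialization
}
```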
@@ -90,8 +90,8 @@
 #endif

 // pass tensor descriptor by value or void*
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0

 // merge transformation use magic number division
 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
...
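The two `CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_*` switches appear to be mutually exclusive (the diff flips them together), so enabling or disabling only one would silently select neither descriptor-passing path. A small sketch of the kind of compile-time guard that could catch that; this is a suggested pattern, not something present in the file, and the values simply mirror the new defaults:

```cpp
// Hypothetical sanity check for the two descriptor-passing switches above.
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + \
    CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER != 1
#error "exactly one tensor-descriptor passing mode must be enabled"
#endif

int main() { return 0; }
```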