"vscode:/vscode.git/clone" did not exist on "f503a848c2aedc51250cfee143cc473d469551de"
Commit 56863b9a authored by Jing Zhang's avatar Jing Zhang
Browse files

add fp8 support

parents 54df59bf d4c84256
@@ -17,6 +17,8 @@ template <index_t InOutRank,
           typename InDataType,
           typename OutDataType,
           typename IndexDataType,
+          typename InLayout,
+          typename OutLayout,
           ReduceTensorOp ReduceOpId,
           bool OutputIndex>
 struct DevicePoolFwd : public BaseOperator
@@ -25,13 +27,14 @@ struct DevicePoolFwd : public BaseOperator
     MakeArgumentPointer(const void* p_in_dev,
                         void* p_out_dev,
                         void* p_out_indices_dev,
-                        std::vector<ck::index_t> input_lengths,
-                        std::vector<ck::index_t> window_lengths,
-                        std::vector<ck::index_t> output_lengths,
-                        std::vector<ck::index_t> input_stride,
-                        std::vector<ck::index_t> output_stride,
-                        std::vector<ck::index_t> indices_stride,
-                        std::vector<ck::index_t> window_strides,
+                        std::vector<ck::index_t> input_n_c_wis_lengths,
+                        std::vector<ck::index_t> window_xs_lengths,
+                        std::vector<ck::index_t> output_n_c_wos_lengths,
+                        std::vector<ck::index_t> input_n_c_wis_stride,
+                        std::vector<ck::index_t> output_n_c_wis_stride,
+                        std::vector<ck::index_t> indices_n_c_wis_stride,
+                        std::vector<ck::index_t> window_xs_strides,
+                        std::vector<ck::index_t> window_xs_dilations,
                         std::vector<ck::index_t> input_left_pads,
                         std::vector<ck::index_t> input_right_pads,
                         std::vector<ck::index_t> pooling_dims) = 0;
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "pool_fwd_instance_common.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;

void add_device_pool2d_fwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
{
    add_device_operation_instances(
        instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {

enum struct GemmDlAlgorithm
{
    Default, // Uses DOT vector instructions
    Dpp8,    // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
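// Illustration (not from this commit): the enum is consumed at compile time
// further down in this diff, e.g.
//   if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8) { /* gfx1030-only path */ }
// inside DeviceGemmDl::IsSupportedArgument.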
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// In and Din = [N, C, Di, Hi, Wi]
// Out and Dout = [N, C, Do, Ho, Wo]
// Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W
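// Illustrative shape example (values assumed, not from the source): with
// In = [1, 8, 16, 16, 16], a 2x2x2 window, stride 2, no dilation and no padding,
// the forward pass yields Out = [1, 8, 8, 8, 8]; this backward operator then
// scatters each Dout element back over its 2x2x2 input window, scaled by 1/8.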
template <typename DOutDataType,
typename DInDataType,
typename ComputeDataType,
ck::index_t BlockSize,
ck::index_t MThreadClusterSize,
ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DeviceAvgPool3dBwd_NDHWC_NDHWC : public DeviceAvgPoolBwd<3,
DOutDataType,
DInDataType,
tensor_layout::convolution::NDHWC,
tensor_layout::convolution::NDHWC>
{
static constexpr ck::index_t NDimSpatial = 3;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto
Make3DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
const std::vector<ck::index_t>& din_n_c_wos_length,
const std::vector<ck::index_t>& dout_n_c_wos_strides,
const std::vector<ck::index_t>& din_n_c_wos_strides,
const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads,
const std::vector<ck::index_t>& tildes)
{
index_t i_ztilde = tildes[0];
index_t i_ytilde = tildes[1];
index_t i_xtilde = tildes[2];
const index_t N = dout_n_c_wos_lengths[0];
const index_t C = dout_n_c_wos_lengths[1];
const index_t Di = din_n_c_wos_length[2];
const index_t Hi = din_n_c_wos_length[3];
const index_t Wi = din_n_c_wos_length[4];
const index_t Do = dout_n_c_wos_lengths[2];
const index_t Ho = dout_n_c_wos_lengths[3];
const index_t Wo = dout_n_c_wos_lengths[4];
const index_t Z = window_lengths[0];
const index_t Y = window_lengths[1];
const index_t X = window_lengths[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t ConvStrideD = window_strides[0];
const index_t ConvStrideH = window_strides[1];
const index_t ConvStrideW = window_strides[2];
const index_t ConvDilationD = window_dilations[0];
const index_t ConvDilationH = window_dilations[1];
const index_t ConvDilationW = window_dilations[2];
const auto out_n_do_ho_wo_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C),
make_tuple(dout_n_c_wos_strides[0],
dout_n_c_wos_strides[2],
dout_n_c_wos_strides[3],
dout_n_c_wos_strides[4],
dout_n_c_wos_strides[1]));
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
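// Worked example (assumed values): for window Z = 3, ConvStrideD = 2 and
// ConvDilationD = 1 we get GcdStrideDilationD = 1, ZTilde = 2,
// ZDot = ceil(3 / 2) = 2 and DTilde = Do + ceil(1 * (3 - 1) / 2) = Do + 1.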
// only work on Tildes that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IDTildeSliceEnd =
math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// ReduceK is different for each Reduce
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// Problem size of reduction kernel
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
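// Example of the padding arithmetic (hypothetical sizes): MRaw = 1000 with
// M_BlockTileSize = 256 gives integer_least_multiple(1000, 256) = 1024, so
// MPad = 24 rows are appended via the right-pad transforms below.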
// Out[ReduceM, ReduceK]
const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc =
transform_tensor_descriptor(
out_n_dop_hop_wop_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)),
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
out_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// In[ReduceM]
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
make_tuple(din_n_c_wos_strides[0],
din_n_c_wos_strides[2],
din_n_c_wos_strides[3],
din_n_c_wos_strides[4],
din_n_c_wos_strides[1]));
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<>{},
Sequence<3>{},
Sequence<4>{}));
const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto in_grid_desc_reducem =
transform_tensor_descriptor(in_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
}
using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0}));
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
// FIXME
// for NDHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
using PassThrough = tensor_operation::element_wise::PassThrough;
using Div = tensor_operation::element_wise::UnaryDivide;
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
reduce::Add,
PassThrough,
Div,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
OutSrcInDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
struct Argument : public BaseArgument
{
Argument(const DOutDataType* p_dout,
DInDataType* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout},
p_din_grid_{p_din},
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
din_n_c_wos_length_{din_n_c_wos_length},
dout_n_c_wos_strides_{dout_n_c_wos_strides},
din_n_c_wos_strides_{din_n_c_wos_strides},
num_reduce_{1},
div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
{
std::vector<ck::index_t> Tildes(NDimSpatial);
for(int i = 0; i < NDimSpatial; ++i)
{
int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
Tildes[i] = window_strides[i] / GcdStrideDilation;
num_reduce_ *= Tildes[i];
}
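// Illustrative example (assumed values): window_strides = {2, 2, 2} with
// window_dilations = {1, 1, 1} gives Tildes = {2, 2, 2}, i.e. num_reduce_ = 8
// reduction kernels, one per (ztilde, ytilde, xtilde) offset enumerated below.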
for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde)
{
for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde)
{
// check slice is valid
const auto ZDotSlice =
math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]);
const auto YDotSlice =
math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]);
const auto XDotSlice =
math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]);
if(ZDotSlice * YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto dout_din_grid_desc =
Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads,
{i_ztilde, i_ytilde, i_xtilde});
dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
}
}
}
}
const DOutDataType* p_dout_grid_;
DInDataType* p_din_grid_;
std::vector<ck::index_t> dout_n_c_wos_lengths_;
std::vector<ck::index_t> din_n_c_wos_length_;
std::vector<ck::index_t> dout_n_c_wos_strides_;
std::vector<ck::index_t> din_n_c_wos_strides_;
int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
std::vector<DinGridDesc_M> din_grid_desc_m_container_;
Div div_element_op_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(index_t i = 0; i < arg.num_reduce_; i++)
{
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
false,
false,
false, // don't have index input
DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
PassThrough,
Div>;
ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
const index_t grid_size = (M / M_BlockTileSize);
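// M comes from out_grid_desc_reducem_reducek, whose M dimension was
// right-padded to a multiple of M_BlockTileSize, so this division is exact.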
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.dout_grid_desc_m_k_container_[i],
arg.din_grid_desc_m_container_[i],
PassThrough{},
arg.div_element_op_,
float(1),
arg.p_dout_grid_,
nullptr,
float(0),
arg.p_din_grid_,
nullptr);
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
constexpr index_t Rank = NDimSpatial + 2;
int doutFastestDim = -1;
int dinFastestDim = -1;
for(int i = 0; i < Rank; ++i)
{
if(arg.dout_n_c_wos_strides_[i] == 1)
doutFastestDim = i;
if(arg.din_n_c_wos_strides_[i] == 1)
dinFastestDim = i;
}
if(doutFastestDim == -1 || dinFastestDim == -1)
{
if constexpr(InSrcOutDstVectorSize != 1)
return false;
}
else
{
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
return false;
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
return false;
}
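// Illustrative check (assumed values): with a fastest-dimension length of 6,
// InSrcOutDstVectorSize = 4 fails the checks above (6 % 4 != 0) while 1 or 2 passes.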
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override
{
constexpr index_t Rank = NDimSpatial + 2;
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
throw std::runtime_error("dimension is incorrect");
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<DInDataType*>(p_din),
dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceAvgPool3dBwd<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
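// Hypothetical usage sketch (template arguments and variable names are
// illustrative, not taken from this commit): BlockSize 256 split as
// 256 M-threads x 1 K-thread, slice sizes 1, vector size 1.
//
//   using AvgPoolBwd =
//       DeviceAvgPool3dBwd_NDHWC_NDHWC<float, float, float, 256, 256, 1, 1, 1, 1>;
//   AvgPoolBwd op;
//   auto arg = op.MakeArgumentPointer(p_dout, p_din,
//                                     dout_lengths, din_lengths,
//                                     dout_strides, din_strides,
//                                     window_lengths, window_strides, window_dilations,
//                                     left_pads, right_pads);
//   if(op.IsSupportedArgument(arg.get()))
//       op.MakeInvokerPointer()->Run(arg.get());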
@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
         ALayout,
         BLayout,
         CLayout,
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
+        BDataType,
         GemmAccDataType,
         CShuffleDataType,
         CDataType,
@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
-                const auto kernel =
-                    kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, true>;
+                const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
+                                                                ADataType,
+                                                                BDataType,
+                                                                CDataType,
+                                                                true>;

                 ave_time += launch_and_time_kernel(stream_config,
                                                    kernel,
@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
             }
             else
             {
-                const auto kernel =
-                    kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, false>;
+                const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
+                                                                ADataType,
+                                                                BDataType,
+                                                                CDataType,
+                                                                false>;

                 ave_time += launch_and_time_kernel(stream_config,
                                                    kernel,
...
@@ -11,6 +11,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
 #include "ck/host_utility/device_prop.hpp"
@@ -59,6 +60,7 @@ template <
     typename CThreadTransferSrcDstAccessOrder,
     index_t CThreadTransferSrcDstVectorDim,
     index_t CThreadTransferDstScalarPerVector,
+    GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default,
     enable_if_t<
         is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
         is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
@@ -236,7 +238,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
         BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
         CThreadTransferSrcDstAccessOrder,
         CThreadTransferSrcDstVectorDim,
-        CThreadTransferDstScalarPerVector>;
+        CThreadTransferDstScalarPerVector,
+        GemmDlAlg>;

     using AGridDesc_K0_M0_M1_K1 =
         decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
@@ -372,7 +375,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     true,
-                    true>;
+                    true,
+                    GemmDlAlg>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -398,7 +402,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     true,
-                    false>;
+                    false,
+                    GemmDlAlg>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -424,7 +429,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     false,
-                    true>;
+                    true,
+                    GemmDlAlg>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -450,7 +456,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     false,
-                    false>;
+                    false,
+                    GemmDlAlg>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -485,6 +492,16 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
+        {
+            if(ck::get_device_name() == "gfx1030")
+            {
+                return GridwiseGemm::CheckValidity(
+                    arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
+            }
+            return false;
+        }
+
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
            ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
            ck::get_device_name() == "gfx1102")
@@ -492,10 +509,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
         {
             return GridwiseGemm::CheckValidity(
                 arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
         }
-        else
-        {
-            return false;
-        }
+        return false;
     }

     // polymorphic
@@ -572,7 +586,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
     }

     // polymorphic
-    std::string GetTypeString() const override
+    virtual std::string GetTypeString() const override
     {
         auto str = std::stringstream();
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
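// DeviceGemmDlDpp8 below is a thin wrapper over DeviceGemmDl that pins the
// final template argument to GemmDlAlgorithm::Dpp8; per IsSupportedArgument
// in device_gemm_dl.hpp above, this path is only taken on gfx1030.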
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDlDpp8 : public DeviceGemmDl<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
GemmDlAlgorithm::Dpp8>
{
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDlDpp8"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -65,7 +65,8 @@ template <typename ALayout,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler(),
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          typename ComputeType = CDataType>
 struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
                                                    BLayout,
                                                    CLayout,
@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
         ALayout,
         BLayout,
         CLayout,
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
+        BDataType,
         GemmAccDataType,
         CShuffleDataType,
         CDataType,
@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        PipelineVer>;
+        PipelineVer,
+        ComputeType>;

     using Argument = typename GridwiseGemm::Argument;
...
@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         Argument(const InDataType* p_in_grid,
                  WeiDataType* p_wei_grid,
                  const OutDataType* p_out_grid,
-                 const ck::index_t G,
-                 const ck::index_t N,
-                 const ck::index_t K,
-                 const ck::index_t C,
-                 const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-                 const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-                 const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
-                 const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
-                 const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
+                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+                 const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
+                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+                 const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
+                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+                 const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
                  const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
                  const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
                  const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
               a_element_op_{out_element_op},
               b_element_op_{wei_element_op},
               c_element_op_{in_element_op},
-              Conv_G_{G},
-              Conv_N_{N},
-              Conv_K_{K},
-              Conv_C_{C},
-              input_spatial_lengths_{input_spatial_lengths},
-              filter_spatial_lengths_{filter_spatial_lengths},
-              output_spatial_lengths_{output_spatial_lengths},
+              Conv_G_{a_g_n_c_wis_lengths[0]},
+              Conv_N_{a_g_n_c_wis_lengths[1]},
+              Conv_K_{b_g_k_c_xs_lengths[1]},
+              Conv_C_{a_g_n_c_wis_lengths[2]},
+              input_spatial_lengths_{},
+              filter_spatial_lengths_{},
+              output_spatial_lengths_{},
               conv_filter_strides_{conv_filter_strides},
               conv_filter_dilations_{conv_filter_dilations},
               input_left_pads_{input_left_pads},
               input_right_pads_{input_right_pads},
               k_batch_{split_k}
         {
+            constexpr index_t spatial_offset = 3;
+            std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
+                      end(a_g_n_c_wis_lengths),
+                      begin(input_spatial_lengths_));
+            std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
+                      end(b_g_k_c_xs_lengths),
+                      begin(filter_spatial_lengths_));
+            std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
+                      end(e_g_n_k_wos_lengths),
+                      begin(output_spatial_lengths_));
+
             const auto descs =
                 DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
-                    N,
-                    K,
-                    C,
-                    input_spatial_lengths,
-                    filter_spatial_lengths,
-                    output_spatial_lengths,
+                    Conv_N_,
+                    Conv_K_,
+                    Conv_C_,
+                    input_spatial_lengths_,
+                    filter_spatial_lengths_,
+                    output_spatial_lengths_,
                     conv_filter_strides,
                     conv_filter_dilations,
                     input_left_pads,
@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
             // A/B/C Batch Stride
             compute_ptr_offset_of_batch_.BatchStrideA_ =
-                N * K *
-                std::accumulate(begin(output_spatial_lengths),
-                                end(output_spatial_lengths),
+                Conv_N_ * Conv_K_ *
+                std::accumulate(begin(output_spatial_lengths_),
+                                end(output_spatial_lengths_),
                                 index_t{1},
                                 std::multiplies<>{});
             compute_ptr_offset_of_batch_.BatchStrideB_ =
-                N * C *
-                std::accumulate(begin(input_spatial_lengths),
-                                end(input_spatial_lengths),
+                Conv_N_ * Conv_C_ *
+                std::accumulate(begin(input_spatial_lengths_),
+                                end(input_spatial_lengths_),
                                 index_t{1},
                                 std::multiplies<>{});
             compute_ptr_offset_of_batch_.BatchStrideC_ =
-                K * C *
-                std::accumulate(begin(filter_spatial_lengths),
-                                end(filter_spatial_lengths),
+                Conv_K_ * Conv_C_ *
+                std::accumulate(begin(filter_spatial_lengths_),
+                                end(filter_spatial_lengths_),
                                 index_t{1},
                                 std::multiplies<>{});
         }
@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         const index_t Conv_K_;
         const index_t Conv_C_;
-        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
-        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
-        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
+        std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
+        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
+        std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
         const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

-    static auto MakeArgument(const InDataType* p_in_grid,
-                             WeiDataType* p_wei_grid,
-                             const OutDataType* p_out_grid,
-                             const ck::index_t G,
-                             const ck::index_t N,
-                             const ck::index_t K,
-                             const ck::index_t C,
-                             const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-                             const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-                             const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
-                             const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
-                             const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
-                             const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
-                             const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
-                             const std::array<ck::index_t, NDimSpatial>& input_left_pads,
-                             const std::array<ck::index_t, NDimSpatial>& input_right_pads,
-                             InElementwiseOperation in_element_op,
-                             WeiElementwiseOperation wei_element_op,
-                             OutElementwiseOperation out_element_op,
-                             ck::index_t split_k)
+    static auto
+    MakeArgument(const InDataType* p_in_grid,
+                 WeiDataType* p_wei_grid,
+                 const OutDataType* p_out_grid,
+                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+                 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+                 InElementwiseOperation in_element_op,
+                 WeiElementwiseOperation wei_element_op,
+                 OutElementwiseOperation out_element_op,
+                 ck::index_t split_k)
     {
         return Argument{p_in_grid,
                         p_wei_grid,
                         p_out_grid,
-                        G,
-                        N,
-                        K,
-                        C,
-                        input_spatial_lengths,
-                        filter_spatial_lengths,
-                        output_spatial_lengths,
-                        input_strides,
-                        output_strides,
+                        a_g_n_c_wis_lengths, // input
+                        a_g_n_c_wis_strides,
+                        b_g_k_c_xs_lengths, // weight
+                        b_g_k_c_xs_strides,
+                        e_g_n_k_wos_lengths, // output
+                        e_g_n_k_wos_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     MakeArgumentPointer(const void* p_in_grid,
                         void* p_wei_grid,
                         const void* p_out_grid,
-                        const ck::index_t G,
-                        const ck::index_t N,
-                        const ck::index_t K,
-                        const ck::index_t C,
-                        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-                        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-                        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
-                        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
-                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
                         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
                         const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                           static_cast<WeiDataType*>(p_wei_grid),
                                           static_cast<const OutDataType*>(p_out_grid),
-                                          G,
-                                          N,
-                                          K,
-                                          C,
-                                          input_spatial_lengths,
-                                          filter_spatial_lengths,
-                                          output_spatial_lengths,
-                                          input_strides,
-                                          output_strides,
+                                          a_g_n_c_wis_lengths, // input
+                                          a_g_n_c_wis_strides,
+                                          b_g_k_c_xs_lengths, // weight
+                                          b_g_k_c_xs_strides,
+                                          e_g_n_k_wos_lengths, // output
+                                          e_g_n_k_wos_strides,
                                           conv_filter_strides,
                                           conv_filter_dilations,
                                           input_left_pads,
...
@@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                        const ck::index_t K,
                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
     {
-        if constexpr(is_GNHWK_GKYXC_GNHWC)
-        {
-            return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
-        }
-        else if constexpr(is_NHWGK_GKYXC_NHWGC)
-        {
-            const index_t WoStride = output_strides[4];
-            const auto KStride     = Number<1>{};
-            return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
-                                                make_tuple(WoStride, KStride));
-        }
-        else
-        {
-            throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
-        }
+        const index_t WoStride = output_strides[4];
+        const auto KStride     = Number<1>{};
+        return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
+                                            make_tuple(WoStride, KStride));
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
@@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                       const ck::index_t C,
                       const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
     {
-        if constexpr(is_GNHWK_GKYXC_GNHWC)
-        {
-            if constexpr(ConvBackwardWeightSpecialization ==
-                         ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
-            {
-                return make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
-            }
-        }
-        else if constexpr(is_NHWGK_GKYXC_NHWGC)
-        {
-            const index_t NStride  = input_strides[1];
-            const index_t HiStride = input_strides[3];
-            const index_t WiStride = input_strides[4];
-            const auto CStride     = input_strides[2];
-            if constexpr(ConvBackwardWeightSpecialization ==
-                         ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
-            {
-                return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
-                                                    make_tuple(WiStride, CStride));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
-            }
-        }
-        else
-        {
-            throw std::runtime_error("wrong! unsupported layout: " + InLayout::name());
-        }
+        const index_t NStride  = input_strides[1];
+        const index_t HiStride = input_strides[3];
+        const index_t WiStride = input_strides[4];
+        const auto CStride     = input_strides[2];
+        if constexpr(ConvBackwardWeightSpecialization ==
+                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
+        {
+            return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
+                                                make_tuple(WiStride, CStride));
+        }
+        else
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
+        }
     }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
+    constexpr static auto
+    make_wei_grid_desc(const ck::index_t K,
+                       const ck::index_t Y,
+                       const ck::index_t X,
+                       const ck::index_t C,
+                       const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
+    {
+        const auto CStride = Number<1>{};
+        const auto KStride = weights_strides[1];
+
+        return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
+    }
     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
     constexpr static auto
     make_out_grid_desc(const ck::index_t N,
@@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                        const ck::index_t K,
                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
     {
-        if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
-        {
-            return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
-        }
-        else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
-        {
-            const index_t WoStride = output_strides[5];
-            const auto KStride     = Number<1>{};
-            return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
-                                                make_tuple(WoStride, KStride));
-        }
-        else
-        {
-            throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
-        }
+        const index_t WoStride = output_strides[5];
+        const auto KStride     = Number<1>{};
+        return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
+                                            make_tuple(WoStride, KStride));
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
@@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                       const ck::index_t C,
                       const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
     {
-        if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
-        {
-            if constexpr(ConvBackwardWeightSpecialization ==
-                         ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
-            {
-                return make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
-            }
-        }
-        else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
-        {
-            const index_t NStride  = input_strides[1];
-            const index_t DiStride = input_strides[3];
-            const index_t HiStride = input_strides[4];
-            const index_t WiStride = input_strides[5];
-            const auto CStride     = input_strides[2];
-            if constexpr(ConvBackwardWeightSpecialization ==
-                         ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
-            {
-                return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
-                                                    make_tuple(WiStride, CStride));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, C),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
-            }
-        }
-        else
-        {
-            throw std::runtime_error("wrong! unsupported layout: " + InLayout::name());
-        }
+        const index_t NStride  = input_strides[1];
+        const index_t DiStride = input_strides[3];
+        const index_t HiStride = input_strides[4];
+        const index_t WiStride = input_strides[5];
+        const auto CStride     = input_strides[2];
+        if constexpr(ConvBackwardWeightSpecialization ==
+                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
+        {
+            return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
+                                                make_tuple(WiStride, CStride));
+        }
+        else
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(N, Di, Hi, Wi, C),
+                make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
+        }
     }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
+    constexpr static auto
+    make_wei_grid_desc(const ck::index_t K,
+                       const ck::index_t Z,
+                       const ck::index_t Y,
+                       const ck::index_t X,
+                       const ck::index_t C,
+                       const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
+    {
+        const auto CStride = Number<1>{};
+        const auto KStride = weights_strides[1];
+
+        return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
+                                            make_tuple(KStride, CStride));
+    }
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const ck::index_t N,
...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* weights_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -409,6 +378,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -409,6 +378,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * X; const index_t GemmN = C * X;
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
...@@ -527,9 +499,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -527,9 +499,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, // Padd
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
wei_gemmm_gemmn_grid_desc); transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
} }
} }
...@@ -542,6 +542,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -542,6 +542,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -576,6 +577,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -576,6 +577,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * X * Y; const index_t GemmN = C * X * Y;
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
...@@ -584,6 +588,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -584,6 +588,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -618,13 +623,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -618,13 +623,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -684,13 +685,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -684,13 +685,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // Pad
const auto wei_gemmm_gemmn_grid_desc = const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, make_tuple(make_pass_through_transform(GemmKBatch),
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, make_pass_through_transform(GemmK0),
wei_gemmm_gemmn_grid_desc); make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
} }
} }
...@@ -703,6 +728,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -703,6 +728,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -744,6 +770,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -744,6 +770,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y; const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
...@@ -752,6 +781,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -752,6 +781,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -786,13 +816,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -786,13 +816,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -861,13 +887,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -861,13 +887,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // Pad
const auto wei_gemmm_gemmn_grid_desc = const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, make_tuple(make_pass_through_transform(GemmKBatch),
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, make_pass_through_transform(GemmK0),
wei_gemmm_gemmn_grid_desc); make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
} }
} // function end } // function end
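Conceptually, make_right_pad_transform extends a dimension so the kernel can iterate a whole number of tiles while accesses past the original extent behave as padding. A rough host-side analogy of that indexing (RightPaddedView is a hypothetical stand-in, not a CK type):

#include <cstdio>
#include <vector>

// A view of length (len + pad) over a buffer of length len; indices past
// the original extent yield a neutral value, which is the role the padded
// tail plays for the GEMM tiles above.
struct RightPaddedView
{
    const std::vector<float>& data;
    int pad;

    float operator[](int i) const
    {
        return i < static_cast<int>(data.size()) ? data[i] : 0.0f; // padded tail
    }
    int size() const { return static_cast<int>(data.size()) + pad; }
};

int main()
{
    std::vector<float> row{1.0f, 2.0f, 3.0f}; // GemmN = 3
    RightPaddedView view{row, 5};             // PadGemmN = 5 -> padded length 8
    for(int i = 0; i < view.size(); ++i)
        std::printf("%g ", view[i]);
    std::printf("\n");
    return 0;
}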
...@@ -887,6 +937,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -887,6 +937,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -910,6 +961,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -910,6 +961,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -933,6 +985,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -933,6 +985,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -1051,15 +1104,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1051,15 +1104,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1084,27 +1134,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1084,27 +1134,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{in_element_op}, b_element_op_{in_element_op},
c_element_op_{wei_element_op}, c_element_op_{wei_element_op},
Conv_G_{G}, Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{N}, Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{K}, Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{C}, Conv_C_{a_g_n_c_wis_lengths[2]},
output_spatial_lengths_{output_spatial_lengths}, input_spatial_lengths_{},
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths_{},
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}, input_right_pads_{input_right_pads},
k_batch_{split_k} k_batch_{split_k}
{ {
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N, Conv_N_,
K, Conv_K_,
C, Conv_C_,
input_spatial_lengths, input_spatial_lengths_,
filter_spatial_lengths, filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths_,
input_strides, a_g_n_c_wis_strides,
output_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1119,12 +1182,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1119,12 +1182,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
// A/B/C Batch Stride // A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = output_strides[0]; compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = input_strides[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideC_ = compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C * Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths), std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths), end(filter_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
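The constructor above now derives G/N/K/C, the spatial lengths, and the per-group weight batch stride from the packed g_n_c_wis/g_k_c_xs arrays instead of taking them as separate scalars. A standalone sketch of that unpacking, with hypothetical sizes:

#include <algorithm>
#include <array>
#include <cstdio>
#include <functional>
#include <numeric>

// G/N/C/K come from fixed slots, spatial lengths start at offset 3, and one
// group of the weight tensor spans K * C * prod(filter spatial lengths).
int main()
{
    constexpr int NDimSpatial    = 2;
    constexpr int spatial_offset = 3;

    std::array<int, NDimSpatial + 3> a_g_n_c_wis_lengths{2, 4, 64, 28, 28}; // input
    std::array<int, NDimSpatial + 3> b_g_k_c_xs_lengths{2, 128, 64, 3, 3};  // weight

    const int G = a_g_n_c_wis_lengths[0];
    const int N = a_g_n_c_wis_lengths[1];
    const int C = a_g_n_c_wis_lengths[2];
    const int K = b_g_k_c_xs_lengths[1];

    std::array<int, NDimSpatial> filter_spatial_lengths{};
    std::copy(b_g_k_c_xs_lengths.begin() + spatial_offset,
              b_g_k_c_xs_lengths.end(),
              filter_spatial_lengths.begin());

    const int batch_stride_c = K * C *
                               std::accumulate(filter_spatial_lengths.begin(),
                                               filter_spatial_lengths.end(),
                                               1,
                                               std::multiplies<>{});

    std::printf("G=%d N=%d K=%d C=%d BatchStrideC=%d\n", G, N, K, C, batch_stride_c);
    return 0;
}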
...@@ -1163,8 +1226,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1163,8 +1226,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t Conv_N_; const index_t Conv_N_;
const index_t Conv_K_; const index_t Conv_K_;
const index_t Conv_C_; const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_; std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_; const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
...@@ -1339,39 +1403,34 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1339,39 +1403,34 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const InDataType* p_in_grid, static auto
WeiDataType* p_wei_grid, MakeArgument(const InDataType* p_in_grid,
const OutDataType* p_out_grid, WeiDataType* p_wei_grid,
const ck::index_t G, const OutDataType* p_out_grid,
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, InElementwiseOperation in_element_op,
const std::array<ck::index_t, NDimSpatial>& input_right_pads, WeiElementwiseOperation wei_element_op,
InElementwiseOperation in_element_op, OutElementwiseOperation out_element_op,
WeiElementwiseOperation wei_element_op, const ck::index_t split_k)
OutElementwiseOperation out_element_op,
const ck::index_t split_k)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
p_out_grid, p_out_grid,
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1390,15 +1449,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1390,15 +1449,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1411,15 +1467,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1411,15 +1467,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
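With the new interface, callers describe each tensor as [G, N-or-K, C, spatial...] lengths plus matching strides. The sketch below assembles such arrays for a hypothetical NDimSpatial = 2 problem and prints the backward-weight GEMM shape implied by the code above (GemmM = K and GemmN = C*Y*X appear in the diff; GemmK = N*Ho*Wo is assumed from context):

#include <array>
#include <cstdio>

int main()
{
    constexpr int NDimSpatial = 2;

    // input  a: [G, N, C, Hi, Wi]
    std::array<int, NDimSpatial + 3> a_lengths{1, 4, 64, 28, 28};
    // weight b: [G, K, C, Y, X]
    std::array<int, NDimSpatial + 3> b_lengths{1, 128, 64, 3, 3};
    // output e: [G, N, K, Ho, Wo] (3x3 filter, stride 1, no padding)
    std::array<int, NDimSpatial + 3> e_lengths{1, 4, 128, 26, 26};

    std::printf("GemmM = K      = %d\n", b_lengths[1]);
    std::printf("GemmN = C*Y*X  = %d\n", b_lengths[2] * b_lengths[3] * b_lengths[4]);
    std::printf("GemmK = N*Ho*Wo = %d\n", e_lengths[1] * e_lengths[3] * e_lengths[4]);
    return 0;
}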
...@@ -214,13 +214,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -214,13 +214,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
using ComputeType = EDataType;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
BDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
DsDataType, DsDataType,
EDataType, EDataType,
ComputeType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
......
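The new ComputeType parameter separates the storage types (A/B/E) from the type used for accumulation. A self-contained sketch of that split (a generic dot product, not CK's gridwise GEMM):

#include <cstdint>
#include <cstdio>

// Accumulate in ComputeType, store in EDataType.
template <typename ADataType, typename BDataType, typename ComputeType, typename EDataType>
EDataType dot(const ADataType* a, const BDataType* b, int k)
{
    ComputeType acc{};
    for(int i = 0; i < k; ++i)
        acc += static_cast<ComputeType>(a[i]) * static_cast<ComputeType>(b[i]);
    return static_cast<EDataType>(acc);
}

int main()
{
    const std::int8_t a[4] = {100, 100, 100, 100};
    const std::int8_t b[4] = {100, 100, 100, 100};
    // A 32-bit compute type keeps the int8 products from overflowing.
    const auto e = dot<std::int8_t, std::int8_t, std::int32_t, std::int32_t>(a, b, 4);
    std::printf("%d\n", e); // 40000
    return 0;
}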
...@@ -3,16 +3,7 @@ ...@@ -3,16 +3,7 @@
#pragma once #pragma once
#include <iostream> #include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -30,255 +21,32 @@ template <typename InDataType, ...@@ -30,255 +21,32 @@ template <typename InDataType,
ck::index_t ReduceMThreadSliceSize, ck::index_t ReduceMThreadSliceSize,
ck::index_t ReduceKThreadSliceSize, ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC<InDataType,
: public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex> OutDataType,
IndexDataType,
ComputeDataType,
ReduceOpId,
OutputIndex,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorSize>
{ {
static constexpr auto I0 = Number<0>{}; using DevicePool3D = DevicePool3dFwd_NDHWC_NDHWC<InDataType,
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
static constexpr ck::index_t ReduceM_BlockTileSize =
ReduceMThreadClusterSize * ReduceMThreadSliceSize;
static constexpr ck::index_t ReduceK_BlockTileSize =
ReduceKThreadClusterSize * ReduceKThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = window_spatial_lengths[0];
const index_t X = window_spatial_lengths[1];
const index_t ConvStrideH = window_strides[0];
const index_t ConvStrideW = window_strides[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ReduceMRaw = N * Ho * Wo * C;
const index_t ReduceMPad =
math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw;
const index_t ReduceKRaw = Y * X;
const index_t ReduceKPad =
math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw;
// A[ReduceM, ReduceK]
const auto in_grid_desc_n_hi_wi_c =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_hi_wi_c,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_hip_wip_c,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_grid_desc_reducemraw_reducekraw =
transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
make_merge_transform(make_tuple(Y, X))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
in_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad),
make_right_pad_transform(ReduceKRaw, ReduceKPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM]
const auto out_grid_desc_reducemraw =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C));
const auto out_grid_desc_reducem = transform_tensor_descriptor(
out_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
}
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
// TODO
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev,
ck::index_t N,
ck::index_t C,
std::vector<ck::index_t>& input_spatial_lengths,
std::vector<ck::index_t>& window_spatial_lengths,
std::vector<ck::index_t>& output_spatial_lengths,
std::vector<ck::index_t>& window_strides,
std::vector<ck::index_t>& input_left_pads,
std::vector<ck::index_t>& input_right_pads)
: p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{},
b_grid_desc_m_{}
{
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
C,
input_spatial_lengths,
window_spatial_lengths,
output_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1];
invariant_lowest_length_ = C;
reduce_lowest_length_ = window_spatial_lengths[1];
int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
}
const InDataType* p_in_dev_;
OutDataType* p_out_dev_;
IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_;
// for checking vector load/store
ck::index_t invariant_lowest_length_;
ck::index_t reduce_lowest_length_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType, OutDataType,
ComputeDataType,
IndexDataType, IndexDataType,
AGridDesc_M_K, ComputeDataType,
BGridDesc_M, ReduceOpId,
ReduceOperation, OutputIndex,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize, BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize, ReduceMThreadSliceSize,
ReduceKThreadSliceSize, ReduceKThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>; InSrcOutDstVectorSize>;
const auto kernel =
kernel_reduce_threadwise<gridwise_reduce,
OutputIndex,
true, // pooling need to return global index
false, // don't have index input
InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.a_grid_desc_m_k_,
arg.b_grid_desc_m_,
arg.in_element_op_,
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
nullptr,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
{
return (false);
}
return (true);
}
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
...@@ -286,62 +54,57 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -286,62 +54,57 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t>, // Suppose tensor layout = NHWC std::vector<ck::index_t> input_stride,
std::vector<ck::index_t>, // Suppose tensor layout = NHWC std::vector<ck::index_t> output_stride,
std::vector<ck::index_t>, // Suppose tensor layout = NHWC std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> pooling_dims) override std::vector<ck::index_t> pooling_dims) override
{ {
static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2;
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
output_lengths.size() != InOutRank || window_strides.size() != WindowRank || output_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
input_right_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3}) if(pooling_dims != std::vector<ck::index_t>{2, 3})
throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far"); throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");
index_t N = input_lengths[0]; // NCHW to NCDHW
index_t C = input_lengths[1]; input_lengths.insert(input_lengths.begin() + 2, 1);
index_t Hi = input_lengths[2]; output_lengths.insert(output_lengths.begin() + 2, 1);
index_t Wi = input_lengths[3]; input_stride.insert(input_stride.begin() + 2, 0);
index_t Ho = output_lengths[2]; output_stride.insert(output_stride.begin() + 2, 0);
index_t Wo = output_lengths[3]; indices_stride.insert(indices_stride.begin() + 2, 0);
std::vector<ck::index_t> input_spatial_lengths = {Hi, Wi}; // YX to ZYX
std::vector<ck::index_t> output_spatial_lengths = {Ho, Wo}; window_lengths.insert(window_lengths.begin(), 1);
window_strides.insert(window_strides.begin(), 0);
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), window_dilations.insert(window_dilations.begin(), 0);
static_cast<OutDataType*>(p_out_dev), input_left_pads.insert(input_left_pads.begin(), 0);
static_cast<IndexDataType*>(p_out_indices_dev), input_right_pads.insert(input_right_pads.begin(), 0);
N,
C, pooling_dims = {2, 3, 4};
input_spatial_lengths,
window_lengths, return DevicePool3D::MakeArgumentPointer(p_in_dev,
output_spatial_lengths, p_out_dev,
window_strides, p_out_indices_dev,
input_left_pads, input_lengths,
input_right_pads); window_lengths,
} output_lengths,
input_stride,
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override output_stride,
{ indices_stride,
return std::make_unique<Invoker>(Invoker{}); window_strides,
} window_dilations,
input_left_pads,
std::string GetTypeString() const override input_right_pads,
{ pooling_dims);
auto str = std::stringstream();
// clang-format off
str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ",";
str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ",";
str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
} }
}; };
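The pool2d op is now a thin adapter over the 3D op: it inserts a dummy depth dimension of length 1 into the tensor descriptors and a dummy Z dimension into the window before forwarding. A standalone sketch of that lowering with hypothetical sizes:

#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> input_lengths{4, 16, 28, 28}; // N, C, Hi, Wi (hypothetical)
    std::vector<int> window_lengths{3, 3};         // Y, X

    input_lengths.insert(input_lengths.begin() + 2, 1); // NCHW -> NCDHW: N, C, 1, Hi, Wi
    window_lengths.insert(window_lengths.begin(), 1);   // YX -> ZYX: 1, Y, X

    for(int v : input_lengths) std::printf("%d ", v);
    std::printf("| ");
    for(int v : window_lengths) std::printf("%d ", v);
    std::printf("\n");
    return 0;
}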
......
...@@ -8,8 +8,10 @@ ...@@ -8,8 +8,10 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -30,8 +32,15 @@ template <typename InDataType, ...@@ -30,8 +32,15 @@ template <typename InDataType,
ck::index_t MThreadSliceSize, ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize, ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize>
struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C struct DevicePool3dFwd_NDHWC_NDHWC : public DevicePoolFwd<5,
: public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex> 3,
InDataType,
OutDataType,
IndexDataType,
tensor_layout::convolution::NDHWC,
tensor_layout::convolution::NDHWC,
ReduceOpId,
OutputIndex>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -51,45 +60,48 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -51,45 +60,48 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced.
static constexpr index_t InSrcOutDstVectorDim = 0;
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_ncdhw_lengths,
ck::index_t C, std::vector<ck::index_t> output_ncdhw_lengths,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_ncdhw_stride,
std::vector<ck::index_t> window_spatial_lengths, std::vector<ck::index_t> output_ncdhw_stride,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> window_spatial_zyx_lengths,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_zyx_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> window_zyx_dilations,
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_left_dhw_pads,
std::vector<ck::index_t> input_right_dhw_pads)
{ {
const index_t Di = input_spatial_lengths[0]; const index_t N = input_ncdhw_lengths[0];
const index_t Hi = input_spatial_lengths[1]; const index_t C = input_ncdhw_lengths[1];
const index_t Wi = input_spatial_lengths[2]; const index_t Di = input_ncdhw_lengths[2];
const index_t Hi = input_ncdhw_lengths[3];
const index_t Wi = input_ncdhw_lengths[4];
const index_t Do = output_ncdhw_lengths[2];
const index_t Ho = output_ncdhw_lengths[3];
const index_t Wo = output_ncdhw_lengths[4];
const index_t Do = output_spatial_lengths[0]; const index_t Z = window_spatial_zyx_lengths[0];
const index_t Ho = output_spatial_lengths[1]; const index_t Y = window_spatial_zyx_lengths[1];
const index_t Wo = output_spatial_lengths[2]; const index_t X = window_spatial_zyx_lengths[2];
const index_t Z = window_spatial_lengths[0]; const index_t WindowStrideD = window_zyx_strides[0];
const index_t Y = window_spatial_lengths[1]; const index_t WindowStrideH = window_zyx_strides[1];
const index_t X = window_spatial_lengths[2]; const index_t WindowStrideW = window_zyx_strides[2];
const index_t ConvStrideD = window_strides[0]; const index_t WindowDilationD = window_zyx_dilations[0];
const index_t ConvStrideH = window_strides[1]; const index_t WindowDilationH = window_zyx_dilations[1];
const index_t ConvStrideW = window_strides[2]; const index_t WindowDilationW = window_zyx_dilations[2];
const index_t InLeftPadD = input_left_pads[0]; const index_t InLeftPadD = input_left_dhw_pads[0];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadH = input_left_dhw_pads[1];
const index_t InLeftPadW = input_left_pads[2]; const index_t InLeftPadW = input_left_dhw_pads[2];
const index_t InRightPadD = input_right_pads[0]; const index_t InRightPadD = input_right_dhw_pads[0];
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadH = input_right_dhw_pads[1];
const index_t InRightPadW = input_right_pads[2]; const index_t InRightPadW = input_right_dhw_pads[2];
const index_t MRaw = N * Do * Ho * Wo * C; const index_t MRaw = N * Do * Ho * Wo * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
...@@ -98,8 +110,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -98,8 +110,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
// A[ReduceM, ReduceK] // A[ReduceM, ReduceK]
const auto in_grid_desc_n_di_hi_wi_c = const index_t Ni_stride = input_ncdhw_stride[0];
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); const index_t Ci_stride = input_ncdhw_stride[1];
const index_t Di_stride = input_ncdhw_stride[2];
const index_t Hi_stride = input_ncdhw_stride[3];
const index_t Wi_stride = input_ncdhw_stride[4];
const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride));
const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_di_hi_wi_c, in_grid_desc_n_di_hi_wi_c,
...@@ -113,11 +132,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -113,11 +132,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_dip_hip_wip_c, in_grid_desc_n_dip_hip_wip_c,
make_tuple(make_pass_through_transform(N), make_tuple(
make_embed_transform(make_tuple(Z, Do), make_tuple(I1, ConvStrideD)), make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
make_pass_through_transform(C)), make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1, 2>{}, Sequence<1, 2>{},
...@@ -139,8 +159,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -139,8 +159,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM] // B[ReduceM]
const auto out_grid_desc_reducemraw = const index_t No_stride = output_ncdhw_stride[0];
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo * C)); const index_t Co_stride = output_ncdhw_stride[1];
const index_t Do_stride = output_ncdhw_stride[2];
const index_t Ho_stride = output_ncdhw_stride[3];
const index_t Wo_stride = output_ncdhw_stride[4];
const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride));
const auto out_grid_desc_reducemraw = transform_tensor_descriptor(
out_grid_desc_n_do_ho_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto out_grid_desc_reducem = const auto out_grid_desc_reducem =
transform_tensor_descriptor(out_grid_desc_reducemraw, transform_tensor_descriptor(out_grid_desc_reducemraw,
...@@ -151,7 +184,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -151,7 +184,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
} }
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {})); using ABGridDescs =
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>; using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>; using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
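The embed transforms above now take (dilation, stride) coefficients, so window tap z at output position d_out reads padded-input coordinate z * WindowDilationD + d_out * WindowStrideD; dilation 1 reproduces the old (I1, ConvStrideD) behaviour. A small sketch printing those coordinates for hypothetical sizes:

#include <cstdio>

int main()
{
    const int Z = 3, Do = 4;
    const int WindowDilationD = 2, WindowStrideD = 2;
    for(int z = 0; z < Z; ++z)
    {
        // Row z lists the input depth coordinates visited by window tap z.
        for(int d_out = 0; d_out < Do; ++d_out)
            std::printf("%2d ", z * WindowDilationD + d_out * WindowStrideD);
        std::printf("\n");
    }
    return 0;
}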
...@@ -160,36 +195,41 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -160,36 +195,41 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
Argument(const InDataType* p_in_dev, Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev, OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev, IndexDataType* p_out_indices_dev,
ck::index_t N, std::vector<ck::index_t>& input_ncdhw_lengths,
ck::index_t C, std::vector<ck::index_t>& output_ncdhw_lengths,
std::vector<ck::index_t>& input_spatial_lengths, std::vector<ck::index_t>& input_ncdhw_stride,
std::vector<ck::index_t>& window_spatial_lengths, std::vector<ck::index_t>& output_ncdhw_stride,
std::vector<ck::index_t>& output_spatial_lengths, std::vector<ck::index_t>&, // indices_ncdhw_stride
std::vector<ck::index_t>& window_strides, std::vector<ck::index_t>& window_spatial_zyx_lengths,
std::vector<ck::index_t>& input_left_pads, std::vector<ck::index_t>& window_zyx_strides,
std::vector<ck::index_t>& input_right_pads) std::vector<ck::index_t>& window_zyx_dilations,
std::vector<ck::index_t>& input_left_dhw_pads,
std::vector<ck::index_t>& input_right_dhw_pads)
: p_in_dev_{p_in_dev}, : p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev}, p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev}, p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{}, a_grid_desc_m_k_{},
b_grid_desc_m_{} b_grid_desc_m_{},
input_ncdhw_lengths_{input_ncdhw_lengths},
output_ncdhw_lengths_{output_ncdhw_lengths},
input_ncdhw_stride_{input_ncdhw_stride},
output_ncdhw_stride_{output_ncdhw_stride}
{ {
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths,
C, output_ncdhw_lengths,
input_spatial_lengths, input_ncdhw_stride,
window_spatial_lengths, output_ncdhw_stride,
output_spatial_lengths, window_spatial_zyx_lengths,
window_strides, window_zyx_strides,
input_left_pads, window_zyx_dilations,
input_right_pads); input_left_dhw_pads,
input_right_dhw_pads);
a_grid_desc_m_k_ = descs[I0]; a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1]; b_grid_desc_m_ = descs[I1];
invariant_lowest_length_ = C; int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] *
window_spatial_zyx_lengths[2];
int32_t reduceLength =
window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2];
std::tie(in_element_op_, acc_element_op_) = std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength); reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
...@@ -200,17 +240,25 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -200,17 +240,25 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
IndexDataType* p_out_indices_dev_; IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_; BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_; AccElementwiseOperation acc_element_op_;
// for checking vector load/store // for checking vector load/store
ck::index_t invariant_lowest_length_; std::vector<ck::index_t> input_ncdhw_lengths_;
std::vector<ck::index_t> output_ncdhw_lengths_;
std::vector<ck::index_t> input_ncdhw_stride_;
std::vector<ck::index_t> output_ncdhw_stride_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
// for NDHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in the M dimension for the reduction kernel.
static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
using gridwise_reduce = using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType, GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType, OutDataType,
...@@ -276,60 +324,66 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -276,60 +324,66 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) // C should be fastest dimension
{ if(pArg->input_ncdhw_stride_[1] != 1)
return false; return false;
for(int i = 0; i < InOutRank; ++i)
{
if(pArg->input_ncdhw_stride_[i] == 1 &&
pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
if(pArg->output_ncdhw_stride_[i] == 1 &&
pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
} }
return true; return true;
} }
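The support check above replaces the old invariant_lowest_length test: C must be the fastest-varying dimension, and any unit-stride dimension's length must divide evenly by the vector size. A standalone restatement of the same predicate (supports_vector_access is a hypothetical stand-in for the member function):

#include <array>
#include <cstdio>

// The unit-stride dimension is the one accessed with vector loads/stores,
// so its length must be a multiple of the vector size.
bool supports_vector_access(const std::array<int, 5>& ncdhw_lengths,
                            const std::array<int, 5>& ncdhw_strides,
                            int vector_size)
{
    if(ncdhw_strides[1] != 1) // C must be the fastest dimension
        return false;
    for(int i = 0; i < 5; ++i)
        if(ncdhw_strides[i] == 1 && ncdhw_lengths[i] % vector_size != 0)
            return false;
    return true;
}

int main()
{
    // Hypothetical packed NDHWC tensor, given in NCDHW order: C (=48) is fastest.
    const std::array<int, 5> lengths{4, 48, 7, 14, 14};
    const std::array<int, 5> strides{65856, 1, 9408, 672, 48};
    std::printf("%s\n", supports_vector_access(lengths, strides, 8) ? "ok" : "unsupported");
    return 0;
}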
std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_ncdhw_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_zyx_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_ncdhw_lengths,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC std::vector<ck::index_t> input_ncdhw_stride,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC std::vector<ck::index_t> output_ncdhw_stride,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC std::vector<ck::index_t> indices_ncdhw_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_zyx_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> window_zyx_dilations,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_left_dhw_pads,
std::vector<ck::index_t> input_right_dhw_pads,
std::vector<ck::index_t> pooling_dims) override std::vector<ck::index_t> pooling_dims) override
{ {
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank || output_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank ||
input_right_dhw_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3, 4}) if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far"); throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
index_t N = input_lengths[0]; if(output_ncdhw_stride != indices_ncdhw_stride)
index_t C = input_lengths[1]; throw std::runtime_error(
index_t Di = input_lengths[2]; "output_ncdhw_stride needs to be equal to indices_ncdhw_stride for now");
index_t Hi = input_lengths[3];
index_t Wi = input_lengths[4];
index_t Do = output_lengths[2];
index_t Ho = output_lengths[3];
index_t Wo = output_lengths[4];
std::vector<ck::index_t> input_spatial_lengths = {Di, Hi, Wi};
std::vector<ck::index_t> output_spatial_lengths = {Do, Ho, Wo};
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev), static_cast<OutDataType*>(p_out_dev),
static_cast<IndexDataType*>(p_out_indices_dev), static_cast<IndexDataType*>(p_out_indices_dev),
N, input_ncdhw_lengths,
C, output_ncdhw_lengths,
input_spatial_lengths, input_ncdhw_stride,
window_lengths, output_ncdhw_stride,
output_spatial_lengths, indices_ncdhw_stride,
window_strides, window_zyx_lengths,
input_left_pads, window_zyx_strides,
input_right_pads); window_zyx_dilations,
input_left_dhw_pads,
input_right_dhw_pads);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...@@ -342,7 +396,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -342,7 +396,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<" << BlockSize << ","; str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
......
...@@ -75,6 +75,12 @@ struct PassThrough ...@@ -75,6 +75,12 @@ struct PassThrough
y = x; y = x;
} }
template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{
y = type_convert<half_t>(x);
}
template <> template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const __host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{ {
......
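The added overload follows PassThrough's existing pattern: the generic operator() copies, and explicit specializations route through type_convert. A self-contained sketch of that pattern, with float standing in for half_t and static_cast for type_convert (PassThroughSketch is hypothetical, not the CK functor):

#include <cstdint>
#include <cstdio>

struct PassThroughSketch
{
    // Generic case: plain copy.
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = x;
    }
};

// Specialized case: int8 -> float goes through an explicit conversion.
template <>
void PassThroughSketch::operator()<float, std::int8_t>(float& y, const std::int8_t& x) const
{
    y = static_cast<float>(x);
}

int main()
{
    const PassThroughSketch pass{};
    std::int8_t x = -5;
    float y       = 0.0f;
    pass(y, x); // deduces <float, int8_t> and takes the specialization
    std::printf("%g\n", y); // -5
    return 0;
}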
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
...@@ -17,6 +19,8 @@ ...@@ -17,6 +19,8 @@
namespace ck { namespace ck {
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -25,7 +29,8 @@ template <typename GridwiseGemm, ...@@ -25,7 +29,8 @@ template <typename GridwiseGemm,
typename CGridDesc_M0_M10_M11_N0_N10_N11, typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop,
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
@@ -38,6 +43,13 @@ __global__ void
        const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
        const Block2CTileMap block_2_ctile_map)
{
// DPP8 is currently only supported on gfx1030
#if !defined(__gfx1030__)
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return;
}
#endif
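    // The guard above makes the Dpp8 variant compile to an immediate return on targets
    // without DPP8 support, so the kernel is a no-op instead of emitting unsupported
    // instructions.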
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -88,7 +100,8 @@ template <index_t BlockSize,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
    static constexpr auto I0 = Number<0>{};
@@ -244,6 +257,45 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                                           c_grid_desc_m_n);
    }
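
    // Picks the blockwise GEMM at compile time: the DPP8 variant shares operand fragments
    // across a lane group via DPP8 cross-lane reads, while the default variant uses the
    // plain DL double-buffered pipeline.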
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
__host__ __device__ static constexpr auto GetBlockwiseGemm()
{
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
else
{
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
}
    using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
@@ -274,7 +326,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this forces index data into SGPR
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
@@ -372,20 +424,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        const auto blockwise_gemm =
            GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>();
        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
@@ -472,7 +511,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                b_block_slice_copy_step);

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
@@ -992,7 +1031,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
                                                b_block_slice_copy_step);

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
...
@@ -29,11 +29,13 @@ namespace ck {
// E = cde_op(C, D0, D1, ...)
// Assume:
//   D0, D1, ... and E have the same layout
template <typename ADataType,
typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
typename ComputeType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
@@ -96,17 +98,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
    __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
@@ -195,8 +186,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
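
        // LDS holds the A and B tiles (each sized with its own element type, since A and B
        // may have different data types) or the C shuffle tile, whichever is larger.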
        return math::max(a_block_space_size_aligned * sizeof(ADataType) +
                             b_block_space_size_aligned * sizeof(BDataType),
                         c_block_size * sizeof(CShuffleDataType));
    }
@@ -401,8 +392,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // check tensor size: cannot be larger than 2GB each
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
             b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
             e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
        {
            return false;
@@ -470,8 +461,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename CDEElementwiseOperation_,
              typename Block2ETileMap>
    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
                               const BDataType* __restrict__ p_b_grid,
                               DsGridPointer p_ds_grid,
                               EDataType* __restrict__ p_e_grid,
                               void* __restrict__ p_shared,
@@ -538,8 +529,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
            Sequence<1, AK0PerBlock, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ADataType,
            ComputeType,
            decltype(a_grid_desc_kbatch_ak0_m_ak1),
            decltype(a_block_desc_kbatch_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
@@ -569,8 +560,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
            Sequence<1, BK0PerBlock, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BDataType,
            ComputeType,
            decltype(b_grid_desc_kbatch_bk0_n_bk1),
            decltype(b_block_desc_kbatch_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
@@ -606,11 +597,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
            ComputeType,
            AccDataType,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
@@ -683,11 +674,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
@@ -999,8 +989,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                  const index_t KBatch,
                                  const Block2ETileMap& block_2_etile_map)
    {
        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);

        using DsGridDesc_M_N =
...
@@ -35,13 +35,17 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}

template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
                                const FloatB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                typename GridwiseGemm::Problem problem)
{
@@ -61,7 +65,8 @@ __global__ void
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename FloatA,
typename FloatB,
          typename FloatGemmAcc,
          typename FloatCShuffle,
          typename FloatC,
@@ -102,7 +107,8 @@ template <typename ALayout,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
    static constexpr auto I0 = Number<0>{};
@@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const FloatA* p_a_grid_,
                          const FloatB* p_b_grid_,
                          FloatC* p_c_grid_,
                          index_t M_,
                          index_t N_,
@@ -479,8 +485,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        {
        }

        const FloatA* p_a_grid;
        const FloatB* p_b_grid;
        FloatC* p_c_grid;
    };
@@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(ComputeType) +
                          b_block_space_size_aligned * sizeof(ComputeType)),
                         c_block_size * sizeof(FloatCShuffle));
    }
@@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const Problem& problem)
@@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatA,
            ComputeType,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
@@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            Sequence<BK0Number, NPerBlock, BK1Number>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatB,
            ComputeType,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
@@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1Number, BK1Number),
                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
            ComputeType,
            FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
@@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

    __device__ static int
    GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
    {
        bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;

        if(is_rightmost_block)
        {
            int left_kPerBlock  = math::integer_divide_ceil(k, kGridSize);
            int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
            int kPerThread      = kRightmostBlock < K_BlockTileSize
                                      ? 0
                                      : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
            int kPerBlockTail   = kRightmostBlock - kPerThread * KThreadClusterSize;
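
            // Illustrative numbers (hypothetical): with padded k = 1024, raw kRaw = 1000 and
            // kGridSize = 4, left_kPerBlock = 256, so the rightmost block handles the
            // kRightmostBlock = 1000 - 256 * 3 = 232 unpadded elements that remain.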
            if(kPerBlockTail > 0)
            {
@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
        }
        else
        {
            int kPerBlock = math::integer_divide_ceil(k, kGridSize);
            return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        }
    }
@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st
        auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto threadwise_welford = ThreadwiseWelford();
        int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
                                                      kRaw,
                                                      k_grid_size,
                                                      block_k_cluster_id,
                                                      thread_k_cluster_id);
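        // GetLength(I1) is the padded K extent of the descriptor, while kRaw is the unpadded
        // one; sizing the rightmost block from kRaw keeps padding out of the Welford count.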

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace ck {
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
 * 4. When `ShareA` is set, `TM1` is divisible by the lane group size
 *    (`dpp8::lane_group_size`); when it is unset, `TN1` is.
*/
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
bool ShareA,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t TK0 = TKLengths{}[I0];
static constexpr index_t TK1 = TKLengths{}[I1];
static constexpr index_t TM0 = TMLengths{}[I0];
static constexpr index_t TM1 = TMLengths{}[I1];
static constexpr index_t TN0 = TNLengths{}[I0];
static constexpr index_t TN1 = TNLengths{}[I1];
static_assert(TM0 == 1 && TN0 == 1);
static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
(!ShareA && TN1 % dpp8::lane_group_size == 0));
static constexpr index_t shared_elems_per_lane =
ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;
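    // Number of fragments of the shared operand each lane keeps locally; the remaining
    // fragments are read from the other lanes of the DPP8 lane group instead of from LDS.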
__device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
                is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
            "wrong! inconsistent type");
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));
constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(0, tm1, 0, tn1));
constexpr int src_lane =
ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
: (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;
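
                    // src_lane picks, within the DPP8 lane group, the lane whose register
                    // holds the fragment needed for this (tm1, tn1) iteration; the dot
                    // product below reads it cross-lane via the DPP8 SEL modifier.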
dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
}
};
} // namespace ck
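
// Worked example (hypothetical values, not part of the header): with ShareA = true,
// TM1 = 16 and dpp8::lane_group_size = 8, shared_elems_per_lane = 16 / 8 = 2, so each
// lane stores two A fragments and borrows the rest from its lane group:
//   local_tm1 = tm1 % 2          // selects the lane-local slot
//   src_lane  = (tm1 / 2) % 8    // selects the owning lane
// e.g. tm1 = 5 reads slot 1 of lane 2 instead of loading that fragment from LDS again.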