Commit c7c47fd7 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Merge branch 'develop' into bwroblew/dpp8

parents f8eb91d7 578142db
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// In and Din = [N, C, Di, Hi, Wi]
// Out and Dout = [N, C, Do, Ho, Wo]
// Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W
template <typename DOutDataType,
typename DInDataType,
typename ComputeDataType,
ck::index_t BlockSize,
ck::index_t MThreadClusterSize,
ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DeviceAvgPool3dBwd_NDHWC_NDHWC : public DeviceAvgPoolBwd<3,
DOutDataType,
DInDataType,
tensor_layout::convolution::NDHWC,
tensor_layout::convolution::NDHWC>
{
static constexpr ck::index_t NDimSpatial = 3;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto
Make3DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
const std::vector<ck::index_t>& din_n_c_wos_length,
const std::vector<ck::index_t>& dout_n_c_wos_strides,
const std::vector<ck::index_t>& din_n_c_wos_strides,
const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_right_pads,
const std::vector<ck::index_t>& tildes)
{
index_t i_ztilde = tildes[0];
index_t i_ytilde = tildes[1];
index_t i_xtilde = tildes[2];
const index_t N = dout_n_c_wos_lengths[0];
const index_t C = dout_n_c_wos_lengths[1];
const index_t Di = din_n_c_wos_length[2];
const index_t Hi = din_n_c_wos_length[3];
const index_t Wi = din_n_c_wos_length[4];
const index_t Do = dout_n_c_wos_lengths[2];
const index_t Ho = dout_n_c_wos_lengths[3];
const index_t Wo = dout_n_c_wos_lengths[4];
const index_t Z = window_lengths[0];
const index_t Y = window_lengths[1];
const index_t X = window_lengths[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t ConvStrideD = window_strides[0];
const index_t ConvStrideH = window_strides[1];
const index_t ConvStrideW = window_strides[2];
const index_t ConvDilationD = window_dilations[0];
const index_t ConvDilationH = window_dilations[1];
const index_t ConvDilationW = window_dilations[2];
const auto out_n_do_ho_wo_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C),
make_tuple(dout_n_c_wos_strides[0],
dout_n_c_wos_strides[2],
dout_n_c_wos_strides[3],
dout_n_c_wos_strides[4],
dout_n_c_wos_strides[1]));
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on Tildes that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IDTildeSliceEnd =
math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// ReduceK is different for each Reduce
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// Problem size of reduction kernel
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
// Out[ReduceM, ReduceK]
const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc =
transform_tensor_descriptor(
out_n_dop_hop_wop_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)),
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
out_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// In[ReduceM]
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
make_tuple(din_n_c_wos_strides[0],
din_n_c_wos_strides[2],
din_n_c_wos_strides[3],
din_n_c_wos_strides[4],
din_n_c_wos_strides[1]));
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(XTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<>{},
Sequence<3>{},
Sequence<4>{}));
const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto in_grid_desc_reducem =
transform_tensor_descriptor(in_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
}
using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0}));
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
// FIXME
// for NDHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
using PassThrough = tensor_operation::element_wise::PassThrough;
using Div = tensor_operation::element_wise::UnaryDivide;
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
reduce::Add,
PassThrough,
Div,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
OutSrcInDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
struct Argument : public BaseArgument
{
Argument(const DOutDataType* p_dout,
DInDataType* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout},
p_din_grid_{p_din},
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
din_n_c_wos_length_{din_n_c_wos_length},
dout_n_c_wos_strides_{dout_n_c_wos_strides},
din_n_c_wos_strides_{din_n_c_wos_strides},
num_reduce_{1},
div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
{
std::vector<ck::index_t> Tildes(NDimSpatial);
for(int i = 0; i < NDimSpatial; ++i)
{
int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
Tildes[i] = window_strides[i] / GcdStrideDilation;
num_reduce_ *= Tildes[i];
}
for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde)
{
for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde)
{
// check slice is valid
const auto ZDotSlice =
math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]);
const auto YDotSlice =
math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]);
const auto XDotSlice =
math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]);
if(ZDotSlice * YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto dout_din_grid_desc =
Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads,
{i_ztilde, i_ytilde, i_xtilde});
dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
}
}
}
}
const DOutDataType* p_dout_grid_;
DInDataType* p_din_grid_;
std::vector<ck::index_t> dout_n_c_wos_lengths_;
std::vector<ck::index_t> din_n_c_wos_length_;
std::vector<ck::index_t> dout_n_c_wos_strides_;
std::vector<ck::index_t> din_n_c_wos_strides_;
int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
std::vector<DinGridDesc_M> din_grid_desc_m_container_;
Div div_element_op_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(index_t i = 0; i < arg.num_reduce_; i++)
{
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
false,
false,
false, // don't have index input
DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
PassThrough,
Div>;
ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
const index_t grid_size = (M / M_BlockTileSize);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.dout_grid_desc_m_k_container_[i],
arg.din_grid_desc_m_container_[i],
PassThrough{},
arg.div_element_op_,
float(1),
arg.p_dout_grid_,
nullptr,
float(0),
arg.p_din_grid_,
nullptr);
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
constexpr index_t Rank = NDimSpatial + 2;
int doutFastestDim = -1;
int dinFastestDim = -1;
for(int i = 0; i < Rank; ++i)
{
if(arg.dout_n_c_wos_strides_[i] == 1)
doutFastestDim = i;
if(arg.din_n_c_wos_strides_[i] == 1)
dinFastestDim = i;
}
if(doutFastestDim == -1 || dinFastestDim == -1)
{
if constexpr(InSrcOutDstVectorSize != 1)
return false;
}
else
{
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
return false;
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
return false;
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_c_wos_lengths,
std::vector<ck::index_t> din_n_c_wos_length,
std::vector<ck::index_t> dout_n_c_wos_strides,
std::vector<ck::index_t> din_n_c_wos_strides,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override
{
constexpr index_t Rank = NDimSpatial + 2;
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
throw std::runtime_error("dimension is incorrect");
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<DInDataType*>(p_din),
dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceAvgPool3dBwd<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
...@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, true>; ADataType,
BDataType,
CDataType,
true>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
} }
else else
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, false>; ADataType,
BDataType,
CDataType,
false>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
......
...@@ -65,7 +65,8 @@ template <typename ALayout, ...@@ -65,7 +65,8 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = CDataType>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
...@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
PipelineVer>; PipelineVer,
ComputeType>;
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
......
...@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{wei_element_op}, b_element_op_{wei_element_op},
c_element_op_{in_element_op}, c_element_op_{in_element_op},
Conv_G_{G}, Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{N}, Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{K}, Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{C}, Conv_C_{a_g_n_c_wis_lengths[2]},
input_spatial_lengths_{input_spatial_lengths}, input_spatial_lengths_{},
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths_{},
output_spatial_lengths_{output_spatial_lengths}, output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations}, conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}, input_right_pads_{input_right_pads},
k_batch_{split_k} k_batch_{split_k}
{ {
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N, Conv_N_,
K, Conv_K_,
C, Conv_C_,
input_spatial_lengths, input_spatial_lengths_,
filter_spatial_lengths, filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths_,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride // A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = compute_ptr_offset_of_batch_.BatchStrideA_ =
N * K * Conv_N_ * Conv_K_ *
std::accumulate(begin(output_spatial_lengths), std::accumulate(begin(output_spatial_lengths_),
end(output_spatial_lengths), end(output_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ = compute_ptr_offset_of_batch_.BatchStrideB_ =
N * C * Conv_N_ * Conv_C_ *
std::accumulate(begin(input_spatial_lengths), std::accumulate(begin(input_spatial_lengths_),
end(input_spatial_lengths), end(input_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ = compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C * Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths), std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths), end(filter_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
} }
...@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const index_t Conv_K_; const index_t Conv_K_;
const index_t Conv_C_; const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_; std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_; std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_; const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
...@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const InDataType* p_in_grid, static auto
WeiDataType* p_wei_grid, MakeArgument(const InDataType* p_in_grid,
const OutDataType* p_out_grid, WeiDataType* p_wei_grid,
const ck::index_t G, const OutDataType* p_out_grid,
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, InElementwiseOperation in_element_op,
const std::array<ck::index_t, NDimSpatial>& input_right_pads, WeiElementwiseOperation wei_element_op,
InElementwiseOperation in_element_op, OutElementwiseOperation out_element_op,
WeiElementwiseOperation wei_element_op, ck::index_t split_k)
OutElementwiseOperation out_element_op,
ck::index_t split_k)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
p_out_grid, p_out_grid,
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t K, const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides) const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{ {
if constexpr(is_GNHWK_GKYXC_GNHWC) const index_t WoStride = output_strides[4];
{ const auto KStride = Number<1>{};
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
} make_tuple(WoStride, KStride));
else if constexpr(is_NHWGK_GKYXC_NHWGC)
{
const index_t WoStride = output_strides[4];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
...@@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides) const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{ {
if constexpr(is_GNHWK_GKYXC_GNHWC) const index_t NStride = input_strides[1];
{ const index_t HiStride = input_strides[3];
if constexpr(ConvBackwardWeightSpecialization == const index_t WiStride = input_strides[4];
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) const auto CStride = input_strides[2];
{ if constexpr(ConvBackwardWeightSpecialization ==
return make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
}
}
else if constexpr(is_NHWGK_GKYXC_NHWGC)
{ {
const index_t NStride = input_strides[1]; return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
const index_t HiStride = input_strides[3]; make_tuple(WiStride, CStride));
const index_t WiStride = input_strides[4];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
}
} }
else else
{ {
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(NStride, HiStride, WiStride, CStride));
} }
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
const ck::index_t Y,
const ck::index_t X,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto constexpr static auto
make_out_grid_desc(const ck::index_t N, make_out_grid_desc(const ck::index_t N,
...@@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t K, const ck::index_t K,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides) const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
{ {
if constexpr(is_GNDHWK_GKZYXC_GNDHWC) const index_t WoStride = output_strides[5];
{ const auto KStride = Number<1>{};
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
} make_tuple(WoStride, KStride));
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
{
const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
}
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
...@@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides) const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
{ {
if constexpr(is_GNDHWK_GKZYXC_GNDHWC) const index_t NStride = input_strides[1];
const index_t DiStride = input_strides[3];
const index_t HiStride = input_strides[4];
const index_t WiStride = input_strides[5];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{ {
if constexpr(ConvBackwardWeightSpecialization == return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) make_tuple(WiStride, CStride));
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
}
}
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
{
const index_t NStride = input_strides[1];
const index_t DiStride = input_strides[3];
const index_t HiStride = input_strides[4];
const index_t WiStride = input_strides[5];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
make_tuple(WiStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
}
} }
else else
{ {
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
} }
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const ck::index_t K,
const ck::index_t Z,
const ck::index_t Y,
const ck::index_t X,
const ck::index_t C,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
make_tuple(KStride, CStride));
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const ck::index_t N,
...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* weights_strides */,
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */, const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -409,6 +378,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -409,6 +378,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * X; const index_t GemmN = C * X;
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
...@@ -527,9 +499,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -527,9 +499,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, // Padd
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
wei_gemmm_gemmn_grid_desc); transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
} }
} }
...@@ -542,6 +542,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -542,6 +542,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -576,6 +577,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -576,6 +577,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * X * Y; const index_t GemmN = C * X * Y;
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
...@@ -584,6 +588,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -584,6 +588,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -618,13 +623,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -618,13 +623,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -684,13 +685,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -684,13 +685,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // Padd
const auto wei_gemmm_gemmn_grid_desc = const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, make_tuple(make_pass_through_transform(GemmKBatch),
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, make_pass_through_transform(GemmK0),
wei_gemmm_gemmn_grid_desc); make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
} }
} }
...@@ -703,6 +728,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -703,6 +728,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
...@@ -744,6 +770,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -744,6 +770,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y; const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
...@@ -752,6 +781,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -752,6 +781,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides); const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides); const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
...@@ -786,13 +816,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -786,13 +816,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_grid_desc);
} }
else else
{ {
...@@ -861,13 +887,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -861,13 +887,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // Padd
const auto wei_gemmm_gemmn_grid_desc = const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, make_tuple(make_pass_through_transform(GemmKBatch),
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, make_pass_through_transform(GemmK0),
wei_gemmm_gemmn_grid_desc); make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
} }
} // function end } // function end
...@@ -887,6 +937,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -887,6 +937,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -910,6 +961,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -910,6 +961,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -933,6 +985,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -933,6 +985,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
lengths, lengths,
strides, strides,
strides, strides,
strides,
params, params,
params, params,
params, params,
...@@ -1051,15 +1104,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1051,15 +1104,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1084,27 +1134,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1084,27 +1134,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{in_element_op}, b_element_op_{in_element_op},
c_element_op_{wei_element_op}, c_element_op_{wei_element_op},
Conv_G_{G}, Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{N}, Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{K}, Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{C}, Conv_C_{a_g_n_c_wis_lengths[2]},
output_spatial_lengths_{output_spatial_lengths}, input_spatial_lengths_{},
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths_{},
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}, input_right_pads_{input_right_pads},
k_batch_{split_k} k_batch_{split_k}
{ {
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N, Conv_N_,
K, Conv_K_,
C, Conv_C_,
input_spatial_lengths, input_spatial_lengths_,
filter_spatial_lengths, filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths_,
input_strides, a_g_n_c_wis_strides,
output_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1119,12 +1182,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1119,12 +1182,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
// A/B/C Batch Stride // A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = output_strides[0]; compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = input_strides[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideC_ = compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C * Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths), std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths), end(filter_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
...@@ -1163,8 +1226,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1163,8 +1226,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const index_t Conv_N_; const index_t Conv_N_;
const index_t Conv_K_; const index_t Conv_K_;
const index_t Conv_C_; const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_; std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_; const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
...@@ -1339,39 +1403,34 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1339,39 +1403,34 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const InDataType* p_in_grid, static auto
WeiDataType* p_wei_grid, MakeArgument(const InDataType* p_in_grid,
const OutDataType* p_out_grid, WeiDataType* p_wei_grid,
const ck::index_t G, const OutDataType* p_out_grid,
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, InElementwiseOperation in_element_op,
const std::array<ck::index_t, NDimSpatial>& input_right_pads, WeiElementwiseOperation wei_element_op,
InElementwiseOperation in_element_op, OutElementwiseOperation out_element_op,
WeiElementwiseOperation wei_element_op, const ck::index_t split_k)
OutElementwiseOperation out_element_op,
const ck::index_t split_k)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
p_out_grid, p_out_grid,
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1390,15 +1449,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1390,15 +1449,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1411,15 +1467,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1411,15 +1467,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -35,13 +35,17 @@ __global__ void ...@@ -35,13 +35,17 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop> template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Problem problem) typename GridwiseGemm::Problem problem)
{ {
...@@ -61,7 +65,8 @@ __global__ void ...@@ -61,7 +65,8 @@ __global__ void
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
...@@ -102,7 +107,8 @@ template <typename ALayout, ...@@ -102,7 +107,8 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched, LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Argument // Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem struct Argument : public tensor_operation::device::BaseArgument, public Problem
{ {
__host__ Argument(const FloatAB* p_a_grid_, __host__ Argument(const FloatA* p_a_grid_,
const FloatAB* p_b_grid_, const FloatB* p_b_grid_,
FloatC* p_c_grid_, FloatC* p_c_grid_,
index_t M_, index_t M_,
index_t N_, index_t N_,
...@@ -479,8 +485,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -479,8 +485,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
} }
const FloatAB* p_a_grid; const FloatA* p_a_grid;
const FloatAB* p_b_grid; const FloatB* p_b_grid;
FloatC* p_c_grid; FloatC* p_c_grid;
}; };
...@@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_block_size = constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned * sizeof(ComputeType) +
sizeof(FloatAB), b_block_space_size_aligned * sizeof(ComputeType)),
c_block_size * sizeof(FloatCShuffle)); c_block_size * sizeof(FloatCShuffle));
} }
...@@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const Problem& problem) const Problem& problem)
...@@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<AK0Number, MPerBlock, AK1Number>, Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatA,
FloatAB, ComputeType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<BK0Number, NPerBlock, BK1Number>, Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatB,
FloatAB, ComputeType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check // sanity check
constexpr index_t KPack = constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatAB, ComputeType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// dinput descriptor in [N, C, Do, Ho, Wo] order
// doutput descriptor in [N, C, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename DInDataType,
typename DOutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceAvgPoolBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
: dinput_{dinput},
doutput_{doutput},
window_spatial_lengths_{window_spatial_lengths},
window_strides_{window_strides},
window_dilations_{window_dilations},
in_left_pads_{dinput_left_pads},
in_right_pads_{dinput_right_pads}
{
}
Tensor<DInDataType>& dinput_;
const Tensor<DOutDataType>& doutput_;
std::vector<ck::index_t> window_spatial_lengths_;
std::vector<index_t> window_strides_;
std::vector<index_t> window_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceAvgPoolBwd::Argument;
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 1, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
// Let input = x, outpu = y
// shape of x = [10], y = [6]
// window_size = 5, pad = 0, stride = 1, dilation = 1
// Forward:
// y0 = 1/5 * (x0 + x1 + x2 + x3 + x4)
// y1 = 1/5 * (x1 + x2 + x3 + x4 + x5)
// ...
// y5 = 1/5 * (x5 + x6 + x7 + x8 + x9)
// y6 = 1/5 * (x6 + x7 + x8 + x9)
// ...
// y9 = 1/5 * (x9)
// Backward:
// shape of dy = [6], dx = [10]
// dx0 = 1/5 * dy0
// dx1 = 1/5 * (dy0 + dy1)
// dx2 = 1/5 * (dy0 + dy1 + dy2)
// ...
// dx4 = 1/5 * (dy0 + dy1 + dy2 + dy3 + dy4)
// dx5 = 1/5 * (dy1 + dy2 + dy3 + dy4 + dy5)
// ...
// dx9 = 1/5 * (dy5 + dy6 + dy7 + dy8 + dy9)
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.doutput_.GetLengths()[2];
float v_acc = 0;
for(std::size_t x = 0; x < X; ++x)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
// Check the input pixel validity (in perspective of being affected by some
// doutput pixel)
if(w_tmp % arg.window_strides_[0] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
}
}
}
v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
}
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 2, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[3];
float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
// Check the input pixel validity (in perspective of being affected by some
// doutput pixel)
if(h_tmp % arg.window_strides_[0] == 0)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc +=
ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_nchw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 3, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
// Check the input pixel validity (in perspective of being affected by some
// doutput pixel)
if(d_tmp % arg.window_strides_[0] == 0)
{
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp =
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.window_dilations_[2]);
if(w_tmp % arg.window_strides_[2] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(
arg.doutput_(n, c, do_, ho, wo));
}
}
}
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncdhw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3],
arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
return RunAvgPoolBwd<NDimSpatial>(arg);
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
{
if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial ||
dinput_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return Argument{dinput,
doutput,
window_spatial_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceAvgPoolBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc); arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
...@@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc); arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
...@@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc); arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
......
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __bf16__
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
...@@ -36,7 +36,8 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( ...@@ -36,7 +36,8 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#ifdef __fp16__
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -56,7 +57,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( ...@@ -56,7 +57,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#ifdef __fp32__
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
...@@ -76,7 +78,8 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( ...@@ -76,7 +78,8 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#ifdef __int8__
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemm<Col, std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
Row, Row,
...@@ -120,7 +123,7 @@ void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( ...@@ -120,7 +123,7 @@ void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -151,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -151,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp32__
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> && if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>) is_same_v<CDataType, float>)
{ {
...@@ -176,8 +179,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -176,8 +179,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs); add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs);
} }
} }
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && #endif
is_same_v<CDataType, half_t>) #ifdef __fp16__
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
...@@ -200,8 +205,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -200,8 +205,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
} }
} }
else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> && #endif
is_same_v<CDataType, bhalf_t>) #ifdef __bf16__
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
...@@ -224,8 +231,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -224,8 +231,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs); add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs);
} }
} }
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> && #endif
is_same_v<CDataType, int8_t>) #ifdef __int8__
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
...@@ -248,7 +257,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -248,7 +257,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs); add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu; using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add; using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -137,3 +137,4 @@ struct DeviceOperationInstanceFactory< ...@@ -137,3 +137,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -91,3 +91,4 @@ struct DeviceOperationInstanceFactory< ...@@ -91,3 +91,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __fp16__
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2, DeviceBatchedGemmSoftmaxGemmPermute<2,
...@@ -58,7 +58,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_ ...@@ -58,7 +58,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
#endif
#ifdef __bf16__
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2, DeviceBatchedGemmSoftmaxGemmPermute<2,
...@@ -100,7 +101,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf ...@@ -100,7 +101,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
#endif
template <typename ADataType, template <typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
...@@ -147,7 +148,7 @@ struct DeviceOperationInstanceFactory< ...@@ -147,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> && is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
Acc0BiasDataType::Size() == 1 && Acc0BiasDataType::Size() == 1 &&
...@@ -164,6 +165,8 @@ struct DeviceOperationInstanceFactory< ...@@ -164,6 +165,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
} }
} }
#endif
#ifdef __bf16__
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> && else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> && is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
Acc0BiasDataType::Size() == 1 && Acc0BiasDataType::Size() == 1 &&
...@@ -180,6 +183,7 @@ struct DeviceOperationInstanceFactory< ...@@ -180,6 +183,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __fp16__
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row, std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col, Col,
...@@ -111,3 +111,4 @@ struct DeviceOperationInstanceFactory< ...@@ -111,3 +111,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -19,7 +19,7 @@ namespace ck { ...@@ -19,7 +19,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __fp16__
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances( void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row, Row,
...@@ -123,7 +123,8 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan ...@@ -123,7 +123,8 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __int8__
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances( void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col, std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row, Row,
...@@ -227,7 +228,7 @@ void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances ...@@ -227,7 +228,7 @@ void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
...@@ -262,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -262,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>) is_same_v<EDataType, half_t>)
{ {
...@@ -295,6 +296,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -295,6 +296,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
op_ptrs); op_ptrs);
} }
} }
#endif
#ifdef __int8__
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> && else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, int8_t>) is_same_v<EDataType, int8_t>)
{ {
...@@ -327,7 +330,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche ...@@ -327,7 +330,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
op_ptrs); op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory< ...@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __fp16__
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2, DeviceBatchedGemmSoftmaxGemmPermute<2,
...@@ -58,7 +58,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g ...@@ -58,7 +58,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
#endif
#ifdef __bf16__
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2, DeviceBatchedGemmSoftmaxGemmPermute<2,
...@@ -100,6 +101,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf ...@@ -100,6 +101,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances); instances);
#endif
template <typename ADataType, template <typename ADataType,
typename B0DataType, typename B0DataType,
...@@ -146,7 +148,7 @@ struct DeviceOperationInstanceFactory< ...@@ -146,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>) is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{ {
...@@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory< ...@@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
} }
} }
#endif
#ifdef __bf16__
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> && else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>) is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{ {
...@@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory< ...@@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __fp32__
// float // float
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -65,7 +65,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn ...@@ -65,7 +65,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear>>>& instances);
#endif
#ifdef __fp64__
// double // double
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -114,7 +115,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn ...@@ -114,7 +115,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
PassThrough, PassThrough,
PassThrough, PassThrough,
Bilinear>>>& instances); Bilinear>>>& instances);
#endif
// Contraction + Bilinear // Contraction + Bilinear
template <index_t NumDimM, template <index_t NumDimM,
index_t NumDimN, index_t NumDimN,
...@@ -149,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -149,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp32__
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> && if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<DDataType, float> && is_same_v<EDataType, float>) is_same_v<DDataType, float> && is_same_v<EDataType, float>)
{ {
...@@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs); op_ptrs);
} }
} }
#endif
#ifdef __fp64__
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> && if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<DDataType, double> && is_same_v<EDataType, double>) is_same_v<DDataType, double> && is_same_v<EDataType, double>)
{ {
...@@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs); op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __fp32__
// float // float
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -65,7 +65,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc ...@@ -65,7 +65,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale>>>& instances);
#endif
#ifdef __fp64__
// double // double
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2, std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
...@@ -114,7 +115,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc ...@@ -114,7 +115,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>>>& instances); Scale>>>& instances);
#endif
// Contraction + Scale // Contraction + Scale
template <index_t NumDimM, template <index_t NumDimM,
index_t NumDimN, index_t NumDimN,
...@@ -148,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -148,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp32__
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> && if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<EDataType, float>) is_same_v<EDataType, float>)
{ {
...@@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs); op_ptrs);
} }
} }
#endif
#ifdef __fp64__
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> && if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<EDataType, double>) is_same_v<EDataType, double>)
{ {
...@@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra ...@@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs); op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment