Unverified Commit 49180fd6 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Grouped 3d conv backward data support (#799)

* Grouped 3d conv backward data support

* Fix comments
parent f82bd593
......@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
switch(s)
{
case ConvolutionBackwardDataSpecialization::Default: return "Default";
case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0:
return "FFilter1x1Stride1Pad0";
case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
default: return "Unrecognized specialization!";
}
}
......
......@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEElementwiseOp>
{
// FIXME
static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now");
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
"wrong! only implemented for 2D and 3D now");
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
......@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
});
static constexpr auto NonSpatialDimsNum = Number<3>{};
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
static constexpr auto HIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
static constexpr auto YIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
// problem definition
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t Z = b_g_k_c_xs_lengths[ZIdx];
const index_t Y = b_g_k_c_xs_lengths[YIdx];
const index_t X = b_g_k_c_xs_lengths[XIdx];
const index_t ConvStrideH = conv_filter_strides_[0];
const index_t ConvStrideW = conv_filter_strides_[1];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations_[0];
const index_t ConvDilationW = conv_filter_dilations_[1];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice <= 0)
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
continue;
}
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
// check slice is valid
const auto ZDotSlice =
NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice * ZDotSlice <= 0)
{
continue;
}
std::array<index_t, NDimSpatial> tildes;
if constexpr(NDimSpatial == 2)
{
tildes = {i_ytilde, i_xtilde};
}
else if constexpr(NDimSpatial == 3)
{
tildes = {i_ztilde, i_ytilde, i_xtilde};
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
// desc for problem definition
const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
block_2_etile_map_container_.push_back(block_2_etile_map);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
tildes);
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
// desc for problem definition
const auto a_grid_desc_m_k =
transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k =
transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
block_2_etile_map_container_.push_back(block_2_etile_map);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
}
}
}
}
......@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector load for A matrix from global memory to LDS
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
......@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// vector load for B matrix from global memory to LDS
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
{
......@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
......@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector store for E
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
{
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
......
......@@ -16,7 +16,7 @@ namespace device {
namespace instance {
// conv2d backward data
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -30,7 +30,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -44,7 +44,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -58,7 +58,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -72,7 +72,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -86,7 +86,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -100,6 +100,91 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough,
PassThrough>>>& instances);
// conv3d backward data
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
......@@ -139,43 +224,96 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
if constexpr(NumDimSpatial == 2)
{
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
else if constexpr(NumDimSpatial == 3)
{
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
}
}
......
......@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataDefault>{});
device_grouped_conv_bwd_data_xdl_bf16_instances<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
device_grouped_conv_bwd_data_xdl_bf16_instances<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
......
......@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataDefault>{});
device_grouped_conv_bwd_data_xdl_f16_instances<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
device_grouped_conv_bwd_data_xdl_f16_instances<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
......
......@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
......@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataDefault>{});
device_grouped_conv_bwd_data_xdl_f32_instances<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
device_grouped_conv_bwd_data_xdl_f32_instances<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
......
......@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
device_grouped_conv_bwd_data_xdl_bf16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
device_grouped_conv_bwd_data_xdl_bf16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
......
......@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
device_grouped_conv_bwd_data_xdl_f16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
device_grouped_conv_bwd_data_xdl_f16_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
......
......@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
......@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
device_grouped_conv_bwd_data_xdl_f32_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
device_grouped_conv_bwd_data_xdl_f32_instances<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
......
add_instance_library(device_grouped_conv3d_bwd_data_instance
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -81,4 +81,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_in
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment