Unverified Commit 49180fd6 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Grouped 3d conv backward data support (#799)

* Grouped 3d conv backward data support

* Fix comments
parent f82bd593
...@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat ...@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
switch(s) switch(s)
{ {
case ConvolutionBackwardDataSpecialization::Default: return "Default"; case ConvolutionBackwardDataSpecialization::Default: return "Default";
case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
return "FFilter1x1Stride1Pad0";
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
......
...@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEElementwiseOp> CDEElementwiseOp>
{ {
// FIXME // FIXME
static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now"); static_assert(NDimSpatial == 2 || NDimSpatial == 3,
"wrong! only implemented for 2D and 3D now");
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
...@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -491,130 +492,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
}); });
static constexpr auto NonSpatialDimsNum = Number<3>{};
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
static constexpr auto HIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
static constexpr auto YIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
: Number<NonSpatialDimsNum + 2>{};
// problem definition // problem definition
const index_t Y = b_g_k_c_xs_lengths[3]; const index_t Z = b_g_k_c_xs_lengths[ZIdx];
const index_t X = b_g_k_c_xs_lengths[4]; const index_t Y = b_g_k_c_xs_lengths[YIdx];
const index_t X = b_g_k_c_xs_lengths[XIdx];
const index_t ConvStrideH = conv_filter_strides_[0]; const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides_[1]; const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations_[0]; const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations_[1]; const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW; const auto XTilde = ConvStrideW / GcdStrideDilationW;
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{ {
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice <= 0) for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{ {
continue; // check slice is valid
} const auto ZDotSlice =
NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
const auto a_grid_desc_ak0_m_ak1 = const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>( const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides, if(YDotSlice * XDotSlice * ZDotSlice <= 0)
b_g_k_c_xs_lengths, {
b_g_k_c_xs_strides, continue;
e_g_n_c_wis_lengths, }
e_g_n_c_wis_strides,
conv_filter_strides, std::array<index_t, NDimSpatial> tildes;
conv_filter_dilations, if constexpr(NDimSpatial == 2)
input_left_pads, {
input_right_pads, tildes = {i_ytilde, i_xtilde};
{i_ytilde, i_xtilde}); }
else if constexpr(NDimSpatial == 3)
const auto b_grid_desc_bk0_n_bk1 = {
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>( tildes = {i_ztilde, i_ytilde, i_xtilde};
a_g_n_k_wos_lengths, }
a_g_n_k_wos_strides, else
b_g_k_c_xs_lengths, {
b_g_k_c_xs_strides, throw std::runtime_error("wrong! only implemented for 2D and 3D now");
e_g_n_c_wis_lengths, }
e_g_n_c_wis_strides,
conv_filter_strides, const auto a_grid_desc_ak0_m_ak1 =
conv_filter_dilations, transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths, a_g_n_k_wos_lengths,
a_g_n_k_wos_strides, a_g_n_k_wos_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i], e_g_n_c_wis_lengths,
ds_g_n_c_wis_strides[i], e_g_n_c_wis_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
{i_ytilde, i_xtilde}); tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
{i_ytilde, i_xtilde});
// desc for problem definition
const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
block_2_etile_map_container_.push_back(block_2_etile_map);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
e_grid_desc_m_n)); a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ds_grid_desc_m_n(i) =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
// desc for problem definition
const auto a_grid_desc_m_k =
transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
const auto b_grid_desc_n_k =
transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
block_2_etile_map_container_.push_back(block_2_etile_map);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
}
} }
} }
} }
...@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -803,7 +846,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector load for A matrix from global memory to LDS // vector load for A matrix from global memory to LDS
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> || if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>) is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{ {
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{ {
...@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -816,7 +861,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
} }
// vector load for B matrix from global memory to LDS // vector load for B matrix from global memory to LDS
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>) if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{ {
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
{ {
...@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -835,7 +881,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> || if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<DLayout, tensor_layout::convolution::NHWGC> || is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> || is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<DLayout, tensor_layout::convolution::GC> || is_same_v<DLayout, tensor_layout::convolution::GC> ||
is_same_v<DLayout, tensor_layout::convolution::G_C>) is_same_v<DLayout, tensor_layout::convolution::G_C>)
...@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -859,7 +907,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// vector store for E // vector store for E
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> || if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>) is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
{ {
// vector store C matrix into global memory // vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
......
...@@ -16,7 +16,7 @@ namespace device { ...@@ -16,7 +16,7 @@ namespace device {
namespace instance { namespace instance {
// conv2d backward data // conv2d backward data
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK, GNHWK,
GKYXC, GKYXC,
...@@ -30,7 +30,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -30,7 +30,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances( void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK, GNHWK,
GKYXC, GKYXC,
...@@ -44,7 +44,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances( ...@@ -44,7 +44,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK, GNHWK,
GKYXC, GKYXC,
...@@ -58,7 +58,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( ...@@ -58,7 +58,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK, NHWGK,
GKYXC, GKYXC,
...@@ -72,7 +72,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances( ...@@ -72,7 +72,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances( void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK, NHWGK,
GKYXC, GKYXC,
...@@ -86,7 +86,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances( ...@@ -86,7 +86,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK, NHWGK,
GKYXC, GKYXC,
...@@ -100,6 +100,91 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( ...@@ -100,6 +100,91 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// conv3d backward data
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename OutLayout, typename OutLayout,
typename WeiLayout, typename WeiLayout,
...@@ -139,43 +224,96 @@ struct DeviceOperationInstanceFactory< ...@@ -139,43 +224,96 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2)
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{ {
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>) if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
{ is_same_v<OutLayout, GNHWK>)
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{ {
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
} }
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutDataType, BF16>) is_same_v<OutLayout, NHWGK>)
{ {
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
} }
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> && else if constexpr(NumDimSpatial == 3)
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>) if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
{ is_same_v<OutLayout, GNDHWK>)
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{ {
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
} }
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutDataType, BF16>) is_same_v<OutLayout, NDHWGK>)
{ {
add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
} }
} }
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] // Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK, GNHWK,
GKYXC, GKYXC,
...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( ...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<GNHWK, device_grouped_conv_bwd_data_xdl_bf16_instances<2,
GKYXC, GNHWK,
Empty_Tuple, GKYXC,
GNHWC, Empty_Tuple,
ConvBwdDataDefault>{}); GNHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<GNHWK, device_grouped_conv_bwd_data_xdl_bf16_instances<2,
GKYXC, GNHWK,
Empty_Tuple, GKYXC,
GNHWC, Empty_Tuple,
ConvBwdDataFilter1x1Stride1Pad0>{}); GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] // Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK, GNHWK,
GKYXC, GKYXC,
...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<GNHWK, device_grouped_conv_bwd_data_xdl_f16_instances<2,
GKYXC, GNHWK,
Empty_Tuple, GKYXC,
GNHWC, Empty_Tuple,
ConvBwdDataDefault>{}); GNHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<GNHWK, device_grouped_conv_bwd_data_xdl_f16_instances<2,
GKYXC, GNHWK,
Empty_Tuple, GKYXC,
GNHWC, Empty_Tuple,
ConvBwdDataFilter1x1Stride1Pad0>{}); GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] // Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances( void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK, GNHWK,
GKYXC, GKYXC,
...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances( ...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<GNHWK, device_grouped_conv_bwd_data_xdl_f32_instances<2,
GKYXC, GNHWK,
Empty_Tuple, GKYXC,
GNHWC, Empty_Tuple,
ConvBwdDataDefault>{}); GNHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<GNHWK, device_grouped_conv_bwd_data_xdl_f32_instances<2,
GKYXC, GNHWK,
Empty_Tuple, GKYXC,
GNHWC, Empty_Tuple,
ConvBwdDataFilter1x1Stride1Pad0>{}); GNHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] // Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK, NHWGK,
GKYXC, GKYXC,
...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( ...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<NHWGK, device_grouped_conv_bwd_data_xdl_bf16_instances<2,
GKYXC, NHWGK,
Empty_Tuple, GKYXC,
NHWGC, Empty_Tuple,
ConvBwdDataDefault>{}); NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_bf16_instances<NHWGK, device_grouped_conv_bwd_data_xdl_bf16_instances<2,
GKYXC, NHWGK,
Empty_Tuple, GKYXC,
NHWGC, Empty_Tuple,
ConvBwdDataFilter1x1Stride1Pad0>{}); NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] // Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK, NHWGK,
GKYXC, GKYXC,
...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances( ...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<NHWGK, device_grouped_conv_bwd_data_xdl_f16_instances<2,
GKYXC, NHWGK,
Empty_Tuple, GKYXC,
NHWGC, Empty_Tuple,
ConvBwdDataDefault>{}); NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f16_instances<NHWGK, device_grouped_conv_bwd_data_xdl_f16_instances<2,
GKYXC, NHWGK,
Empty_Tuple, GKYXC,
NHWGC, Empty_Tuple,
ConvBwdDataFilter1x1Stride1Pad0>{}); NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_data_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] // Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances( void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK, NHWGK,
GKYXC, GKYXC,
...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances( ...@@ -26,19 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<NHWGK, device_grouped_conv_bwd_data_xdl_f32_instances<2,
GKYXC, NHWGK,
Empty_Tuple, GKYXC,
NHWGC, Empty_Tuple,
ConvBwdDataDefault>{}); NHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_data_xdl_f32_instances<NHWGK, device_grouped_conv_bwd_data_xdl_f32_instances<2,
GKYXC, NHWGK,
Empty_Tuple, GKYXC,
NHWGC, Empty_Tuple,
ConvBwdDataFilter1x1Stride1Pad0>{}); NHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
add_instance_library(device_grouped_conv3d_bwd_data_instance
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -81,4 +81,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_in ...@@ -81,4 +81,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_in
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment