Unverified Commit 475188ca authored by Bartłomiej Kocot, committed by GitHub
Browse files

Add grouped conv bwd weight dl instances and new layout (#897)

* Add grouped conv bwd weight dl instances and new layout

* Add M and N padding

* Remove todo comment

* Enable grouped conv fwd dl k,c=1 generic instance

* Comment fixes
parent a66d14ed
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
using InDataType = F16; using InDataType = F16;
using WeiDataType = F16; using WeiDataType = F16;
...@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough; ...@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough;
using OutElementOp = PassThrough; using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl< NDimSpatial, // NDimSpatial
NDimSpatial, // NDimSpatial ck::tuple_element_t<NDimSpatial - 1,
InDataType, // InDataType ck::Tuple<ck::tensor_layout::convolution::GNWC,
WeiDataType, // WeiDataType ck::tensor_layout::convolution::GNHWC,
OutDataType, // OutDataType ck::tensor_layout::convolution::GNDHWC>>, // InLayout
AccDataType, // AccDataType ck::tuple_element_t<NDimSpatial - 1,
InElementOp, // InElementwiseOperation ck::Tuple<ck::tensor_layout::convolution::GKXC,
WeiElementOp, // WeiElementwiseOperation ck::tensor_layout::convolution::GKYXC,
OutElementOp, // OutElementwiseOperation ck::tensor_layout::convolution::GKZYXC>>, // WeiLayout
ConvBwdWeightDefault, // ConvBackwardWeightSpecialization ck::tuple_element_t<NDimSpatial - 1,
256, // BlockSize ck::Tuple<ck::tensor_layout::convolution::GNWK,
128, // MPerBlock ck::tensor_layout::convolution::GNHWK,
128, // NPerBlock ck::tensor_layout::convolution::GNDHWK>>, // OutLayout
16, // K0PerBlock InDataType, // InDataType
2, // K1 WeiDataType, // WeiDataType
4, // M1PerThread OutDataType, // OutDataType
4, // N1PerThread AccDataType, // AccDataType
1, // KPerThread InElementOp, // InElementwiseOperation
S<8, 2>, // M1N1ThreadClusterM1Xs WeiElementOp, // WeiElementwiseOperation
S<8, 2>, // M1N1ThreadClusterN1Xs OutElementOp, // OutElementwiseOperation
S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 ConvBwdWeightDefault, // ConvBackwardWeightSpecialization
S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 256, // BlockSize
S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder 128, // MPerBlock
S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder 128, // NPerBlock
S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 16, // K0PerBlock
S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder 2, // K1
S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 4, // M1PerThread
S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 4, // N1PerThread
S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 1, // KPerThread
S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder S<8, 2>, // M1N1ThreadClusterM1Xs
S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder S<8, 2>, // M1N1ThreadClusterN1Xs
S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder
5, // CThreadTransferSrcDstVectorDim S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
4>; // CThreadTransferDstScalarPerVector S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder
S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
4>; // CThreadTransferDstScalarPerVector
#include "run_grouped_conv_bwd_weight_example.inc" #include "run_grouped_conv_bwd_weight_example.inc"
......
...@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial> ...@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_param) const ck::utils::conv::ConvParam& conv_param)
{ {
ck::index_t split_k;
// Set split_k = 2 for xdl op, split_k = 1 for dl
// Dl op doesn't support split_k > 1 // Dl op doesn't support split_k > 1
// TODO: Add Dl op split_k > 1 support constexpr ck::index_t split_k = 1;
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
{
split_k = 2;
}
else
{
split_k = 1;
}
const auto in_g_n_c_wis_desc = const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -72,6 +73,9 @@ __global__ void ...@@ -72,6 +73,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...@@ -96,9 +100,27 @@ __global__ void ...@@ -96,9 +100,27 @@ __global__ void
block_2_ctile_map, block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
compute_ptr_offset_of_batch.GetCPtrOffset(0);
#endif
} }
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
...@@ -134,29 +156,46 @@ template <ck::index_t NDimSpatial, ...@@ -134,29 +156,46 @@ template <ck::index_t NDimSpatial,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector> index_t CThreadTransferDstScalarPerVector>
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpatial,
: public DeviceGroupedConvBwdWeight< InLayout,
NDimSpatial, WeiLayout,
ck::tuple_element_t<NDimSpatial - 1, OutLayout,
ck::Tuple<ck::tensor_layout::convolution::GNWC, InDataType,
ck::tensor_layout::convolution::GNHWC, WeiDataType,
ck::tensor_layout::convolution::GNDHWC>>, OutDataType,
ck::tuple_element_t<NDimSpatial - 1, InElementwiseOperation,
ck::Tuple<ck::tensor_layout::convolution::GKXC, WeiElementwiseOperation,
ck::tensor_layout::convolution::GKYXC, OutElementwiseOperation>
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{ {
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl; // 1d
static constexpr bool is_NWGK_GKXC_NWGC =
is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
static constexpr bool is_GNWK_GKXC_GNWC =
is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
// 2d
static constexpr bool is_NHWGK_GKYXC_NHWGC =
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
static constexpr bool is_GNHWK_GKYXC_GNHWC =
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
// 3d
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
using DeviceOp = DeviceGroupedConvBwdWeight_Dl;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = InDataType; using BDataType = InDataType;
...@@ -176,6 +215,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -176,6 +215,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto spatial_offset = I3;
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number; static constexpr auto GemmK1Number = K1Number;
...@@ -195,12 +236,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -195,12 +236,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -209,90 +250,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -209,90 +250,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{ {
using namespace ck; using namespace ck;
const index_t Wi = input_spatial_lengths[0]; const index_t N = a_g_n_c_wis_lengths[I1];
const index_t Wo = output_spatial_lengths[0]; const index_t K = b_g_k_c_xs_lengths[I1];
const index_t X = filter_spatial_lengths[0]; const index_t C = a_g_n_c_wis_lengths[I2];
const index_t InLeftPadW = input_left_pads[0]; const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
const index_t InRightPadW = input_right_pads[0]; const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
const index_t ConvStrideW = conv_filter_strides[0]; const index_t X = b_g_k_c_xs_lengths[spatial_offset];
const index_t ConvDilationW = conv_filter_dilations[0]; const index_t InLeftPadW = input_left_pads[I0];
const index_t InRightPadW = input_right_pads[I0];
const index_t ConvStrideW = conv_filter_strides[I0];
const index_t ConvDilationW = conv_filter_dilations[I0];
const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];
const index_t GemmKTotal = N * Wo; const index_t GemmKTotal = N * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor // B: input tensor
const auto in_gemmktotal_gemmn_grid_desc = const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weights tensor // C: weights tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
else else
{ {
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_wi_c_grid_desc = const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -321,38 +374,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -321,38 +374,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmN)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -361,103 +419,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -361,103 +419,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{ {
using namespace ck; using namespace ck;
const index_t Hi = input_spatial_lengths[0]; const index_t N = a_g_n_c_wis_lengths[I1];
const index_t Wi = input_spatial_lengths[1]; const index_t K = b_g_k_c_xs_lengths[I1];
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Ho = output_spatial_lengths[0]; const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
const index_t Wo = output_spatial_lengths[1]; const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
const index_t Y = filter_spatial_lengths[0]; const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
const index_t X = filter_spatial_lengths[1]; const index_t Y = b_g_k_c_xs_lengths[spatial_offset];
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1]; const index_t InLeftPadH = input_left_pads[I0];
const index_t InLeftPadW = input_left_pads[I1];
const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadH = input_right_pads[I0];
const index_t InRightPadW = input_right_pads[1]; const index_t InRightPadW = input_right_pads[I1];
const index_t ConvStrideH = conv_filter_strides[I0];
const index_t ConvStrideH = conv_filter_strides[0]; const index_t ConvStrideW = conv_filter_strides[I1];
const index_t ConvStrideW = conv_filter_strides[1]; const index_t ConvDilationH = conv_filter_dilations[I0];
const index_t ConvDilationW = conv_filter_dilations[I1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1]; const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InHStride = a_g_n_c_wis_strides[spatial_offset];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];
const index_t GemmKTotal = N * Ho * Wo; const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor // B: input tensor
const auto in_gemmktotal_gemmn_grid_desc = const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
else else
{ {
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -488,39 +554,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -488,39 +554,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmN)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -529,110 +600,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -529,110 +600,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
{ {
using namespace ck; using namespace ck;
const index_t Di = input_spatial_lengths[0]; const index_t N = a_g_n_c_wis_lengths[I1];
const index_t Hi = input_spatial_lengths[1]; const index_t K = b_g_k_c_xs_lengths[I1];
const index_t Wi = input_spatial_lengths[2]; const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
const index_t Do = output_spatial_lengths[0]; const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
const index_t Ho = output_spatial_lengths[1]; const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
const index_t Wo = output_spatial_lengths[2]; const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
const index_t Z = filter_spatial_lengths[0]; const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
const index_t Y = filter_spatial_lengths[1]; const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0];
const index_t X = filter_spatial_lengths[2]; const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1];
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadD = input_left_pads[I0];
const index_t InLeftPadW = input_left_pads[2]; const index_t InLeftPadH = input_left_pads[I1];
const index_t InLeftPadW = input_left_pads[I2];
const index_t InRightPadD = input_right_pads[0]; const index_t InRightPadD = input_right_pads[I0];
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadH = input_right_pads[I1];
const index_t InRightPadW = input_right_pads[2]; const index_t InRightPadW = input_right_pads[I2];
const index_t ConvStrideD = conv_filter_strides[I0];
const index_t ConvStrideD = conv_filter_strides[0]; const index_t ConvStrideH = conv_filter_strides[I1];
const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[I2];
const index_t ConvStrideW = conv_filter_strides[2]; const index_t ConvDilationD = conv_filter_dilations[I0];
const index_t ConvDilationH = conv_filter_dilations[I1];
const index_t ConvDilationD = conv_filter_dilations[0]; const index_t ConvDilationW = conv_filter_dilations[I2];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2]; const auto InNStride = a_g_n_c_wis_strides[I1];
const auto InCStride = a_g_n_c_wis_strides[I2];
const auto InDStride = a_g_n_c_wis_strides[spatial_offset];
const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1];
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2];
const auto WeiKStride = b_g_k_c_xs_strides[I1];
const auto WeiCStride = b_g_k_c_xs_strides[I2];
const auto OutKStride = e_g_n_k_wos_strides[I2];
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];
const index_t GemmKTotal = N * Do * Ho * Wo; const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const index_t GemmKBatch = batch_k; const index_t GemmKBatch = batch_k;
const index_t GemmK0 = const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization == if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{ {
// A: output tensor // A: output tensor
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor // B: input tensor
const auto in_gemmktotal_gemmn_grid_desc = const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
else else
{ {
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
const auto in_n_di_hi_wi_c_grid_desc = const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_tuple(N, Di, Hi, Wi, C),
make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmmpad_grid_desc =
out_gemmktotal_gemmm_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), out_gemmktotal_gemmm_grid_desc,
make_pass_through_transform(GemmM)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmmpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmM)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...@@ -672,27 +753,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -672,27 +753,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmnpad_grid_desc =
in_gemmktotal_gemmn_grid_desc, ck::tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), in_gemmktotal_gemmn_grid_desc,
make_pass_through_transform(GemmN)), make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<true, true>{});
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmnpad_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), make_tuple(
make_pass_through_transform(GemmN)), make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
const auto wei_gemmmpad_gemmnpad_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(MPerBlock, NPerBlock),
Sequence<true, true>{});
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmmpad_gemmnpad_grid_desc);
} }
} // function end } // function end
...@@ -701,22 +787,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -701,22 +787,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1},
1, {1, 1, 1},
1, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
{1, 1, 1}, {1, 1, 1},
...@@ -785,11 +871,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -785,11 +871,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -809,38 +895,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -809,38 +895,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{wei_element_op}, b_element_op_{wei_element_op},
c_element_op_{in_element_op}, c_element_op_{in_element_op},
Conv_G_{a_g_n_c_wis_lengths[0]}, Conv_G_{a_g_n_c_wis_lengths[I0]},
Conv_N_{a_g_n_c_wis_lengths[1]}, Conv_K_{b_g_k_c_xs_lengths[I1]},
Conv_K_{b_g_k_c_xs_lengths[1]}, Conv_C_{a_g_n_c_wis_lengths[I2]},
Conv_C_{a_g_n_c_wis_lengths[2]}, filter_lengths_{b_g_k_c_xs_lengths},
input_spatial_lengths_{},
filter_spatial_lengths_{},
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations}, conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}, input_right_pads_{input_right_pads},
k_batch_{split_k} k_batch_{split_k}
{ {
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_, a_g_n_c_wis_lengths, // input
Conv_K_, a_g_n_c_wis_strides,
Conv_C_, b_g_k_c_xs_lengths, // weight
input_spatial_lengths_, b_g_k_c_xs_strides,
filter_spatial_lengths_, e_g_n_k_wos_lengths, // output
output_spatial_lengths_, e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -863,24 +935,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -863,24 +935,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
// A/B/C Batch Stride // A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
Conv_N_ * Conv_K_ * compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
std::accumulate(begin(output_spatial_lengths_), compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
end(output_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
Conv_N_ * Conv_C_ *
std::accumulate(begin(input_spatial_lengths_),
end(input_spatial_lengths_),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
} }
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
...@@ -908,13 +965,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -908,13 +965,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
const index_t Conv_G_; const index_t Conv_G_;
const index_t Conv_N_;
const index_t Conv_K_; const index_t Conv_K_;
const index_t Conv_C_; const index_t Conv_C_;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_; std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_; const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
...@@ -1036,10 +1090,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1036,10 +1090,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || // DL version only supports split_k equal to 1
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(arg.k_batch_ != 1)
ck::get_device_name() == "gfx1102")) return false;
if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
(NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
(NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
{ {
return false; return false;
} }
...@@ -1050,8 +1108,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1050,8 +1108,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// check if it's 1x1, stride=1 pad = 0 conv // check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++) for(int i = 0; i < NDimSpatial; i++)
{ {
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
arg.input_right_pads_[i] == 0))
{ {
return false; return false;
} }
...@@ -1206,7 +1265,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1206,7 +1265,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl" str << "DeviceGroupedConvBwdWeight_Dl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f ...@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c); c);
} }
// Scalar bhalf_t x bhalf_t specialization with a float result operand.
// Both inputs are widened with type_convert<float> and the work is forwarded to
// inner_product on the converted values; `c` is passed through by reference
// (presumably the running accumulator -- confirm against the primary template).
template <>
__device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
{
    const float lhs_as_float = type_convert<float>(a);
    const float rhs_as_float = type_convert<float>(b);

    inner_product(lhs_as_float, rhs_as_float, c);
}
// Scalar half_t x half_t specialization with a float result operand.
// Both inputs are widened with type_convert<float> and the work is forwarded to
// inner_product on the converted values; `c` is passed through by reference
// (presumably the running accumulator -- confirm against the primary template).
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
    const float lhs_as_float = type_convert<float>(a);
    const float rhs_as_float = type_convert<float>(b);

    inner_product(lhs_as_float, rhs_as_float, c);
}
template <> template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Bring the names of ck::tensor_layout::convolution into scope for this namespace.
using namespace ck::tensor_layout::convolution;

// Element-wise operation used for the In/Wei/Out operation slots of every
// instance listed in this file.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Compile-time integer list; used below for the tile and transfer shape arguments.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Empty_Tuple = ck::Tuple<>;

// Data-type shorthands.
using F32  = float;
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;

// Named values of ConvolutionBackwardWeightSpecialization.
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvBwdWeightDefault =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_f32_instances = std::tuple<
// clang-format off
//############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
//############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_f16_instances = std::tuple<
// clang-format off
//############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
//############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
// Tuple of DeviceGroupedConvBwdWeight_Dl instantiations for the BF16 input / F32 weight /
// BF16 output data-type combination, accumulating in F32 (see the data-type columns below).
//
// NDimSpatial, the three layout types and ConvSpec are forwarded unchanged into the instance,
// so this one alias can be instantiated for any dimension / layout / specialization.
//
// NOTE(review): the A/B block-transfer column headers name four dimensions
// (_K0_M0_M1_K1 / _K0_N0_N1_K1) while every sequence in the instance has five entries --
// confirm what the extra entry represents and update the headers if they are stale.
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_bf16_instances = std::tuple<
// clang-format off
//############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
//############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
// All Src/Dst vector tensor lengths are S<1, 1, 1, 1, 1> and DstScalarPerVector is 1;
// presumably this keeps the instance usable regardless of tensor alignment -- TODO confirm.
DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -55,8 +55,8 @@ using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple< ...@@ -55,8 +55,8 @@ using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple<
// ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instances // generic instances
// TODO: Change to ScalarPerVector = 1 when inner_product<half_t, half_t, float> will be supported // TODO: Change to ScalarPerVector = 1 when inner_product<half_t, half_t, float> will be supported
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 2, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 1, 2>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 1, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 1>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
......
...@@ -17,6 +17,7 @@ namespace tensor_operation { ...@@ -17,6 +17,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// xdl
// conv1d backward weight // conv1d backward weight
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
...@@ -200,6 +201,228 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances ...@@ -200,6 +201,228 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#ifdef DL_KERNELS
// dl
// conv1d backward weight
// One declaration per layout / data-type combination. Template arguments are laid out in
// three rows: spatial dimension and in/wei/out layouts, then in/wei/out data types, then the
// three element-wise operations (row meaning inferred from the function names -- confirm
// against the DeviceGroupedConvBwdWeight declaration).

// GNWC / GKXC / GNWK
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, GNWC, GKXC, GNWK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, GNWC, GKXC, GNWK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, GNWC, GKXC, GNWK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);

// NWGC / GKXC / NWGK
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, NWGC, GKXC, NWGK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, NWGC, GKXC, NWGK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, NWGC, GKXC, NWGK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);
// conv2d backward weight
// One declaration per layout / data-type combination. Template arguments are laid out in
// three rows: spatial dimension and in/wei/out layouts, then in/wei/out data types, then the
// three element-wise operations (row meaning inferred from the function names -- confirm
// against the DeviceGroupedConvBwdWeight declaration).

// GNHWC / GKYXC / GNHWK
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        2, GNHWC, GKYXC, GNHWK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        2, GNHWC, GKYXC, GNHWK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        2, GNHWC, GKYXC, GNHWK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);

// NHWGC / GKYXC / NHWGK
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        2, NHWGC, GKYXC, NHWGK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        2, NHWGC, GKYXC, NHWGK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        2, NHWGC, GKYXC, NHWGK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);
// conv3d backward weight
// One declaration per layout / data-type combination. Template arguments are laid out in
// three rows: spatial dimension and in/wei/out layouts, then in/wei/out data types, then the
// three element-wise operations (row meaning inferred from the function names -- confirm
// against the DeviceGroupedConvBwdWeight declaration).

// GNDHWC / GKZYXC / GNDHWK
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        3, GNDHWC, GKZYXC, GNDHWK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        3, GNDHWC, GKZYXC, GNDHWK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        3, GNDHWC, GKZYXC, GNDHWK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);

// NDHWGC / GKZYXC / NDHWGK
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        3, NDHWGC, GKZYXC, NDHWGK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        3, NDHWGC, GKZYXC, NDHWGK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        3, NDHWGC, GKZYXC, NDHWGK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -242,21 +465,54 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -242,21 +465,54 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
} }
else if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
{
#ifdef DL_KERNELS
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NumDimSpatial == 2)
{ {
...@@ -266,12 +522,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -266,12 +522,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs); op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs); op_ptrs);
} }
...@@ -279,6 +543,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -279,6 +543,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
...@@ -289,12 +557,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -289,12 +557,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs); op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs); op_ptrs);
} }
...@@ -302,6 +578,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -302,6 +578,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
...@@ -315,12 +595,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -315,12 +595,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs); op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs); op_ptrs);
} }
...@@ -328,6 +616,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -328,6 +616,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
...@@ -338,12 +630,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -338,12 +630,20 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs); op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs); op_ptrs);
} }
...@@ -351,6 +651,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -351,6 +651,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
......
add_instance_library(device_grouped_conv1d_bwd_weight_instance set(GROUPED_CONV1D_BWD_WEIGHT
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp)
)
if(DL_KERNELS)
list(APPEND GROUPED_CONV1D_BWD_WEIGHT
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp)
endif()
add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Passes the DL BF16/F32/BF16 instance tuples for 1D GNWC/GKXC/GNWK backward-weight
// convolution to add_device_operation_instances, once per convolution specialization.
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, GNWC, GKXC, GNWK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Instance tuple for the default specialization.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_bf16_instances<1, GNWC, GKXC, GNWK, ConvBwdWeightDefault>;

    // Instance tuple for the Filter1x1Stride1Pad0 specialization.
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_bf16_instances<1,
                                                         GNWC,
                                                         GKXC,
                                                         GNWK,
                                                         ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: default first, then Filter1x1Stride1Pad0.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Passes the DL F16 instance tuples for 1D GNWC/GKXC/GNWK backward-weight convolution to
// add_device_operation_instances, once per convolution specialization.
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, GNWC, GKXC, GNWK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Instance tuple for the default specialization.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_f16_instances<1, GNWC, GKXC, GNWK, ConvBwdWeightDefault>;

    // Instance tuple for the Filter1x1Stride1Pad0 specialization.
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_f16_instances<1,
                                                        GNWC,
                                                        GKXC,
                                                        GNWK,
                                                        ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: default first, then Filter1x1Stride1Pad0.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Passes the DL F32 instance tuples for 1D GNWC/GKXC/GNWK backward-weight convolution to
// add_device_operation_instances, once per convolution specialization.
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, GNWC, GKXC, GNWK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Instance tuple for the default specialization.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_f32_instances<1, GNWC, GKXC, GNWK, ConvBwdWeightDefault>;

    // Instance tuple for the Filter1x1Stride1Pad0 specialization.
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_f32_instances<1,
                                                        GNWC,
                                                        GKXC,
                                                        GNWK,
                                                        ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: default first, then Filter1x1Stride1Pad0.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Passes the DL BF16/F32/BF16 instance tuples for 1D NWGC/GKXC/NWGK backward-weight
// convolution to add_device_operation_instances, once per convolution specialization.
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, NWGC, GKXC, NWGK,
        BF16, F32, BF16,
        PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Instance tuple for the default specialization.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_bf16_instances<1, NWGC, GKXC, NWGK, ConvBwdWeightDefault>;

    // Instance tuple for the Filter1x1Stride1Pad0 specialization.
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_bf16_instances<1,
                                                         NWGC,
                                                         GKXC,
                                                         NWGK,
                                                         ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: default first, then Filter1x1Stride1Pad0.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Passes the DL F16 instance tuples for 1D NWGC/GKXC/NWGK backward-weight convolution to
// add_device_operation_instances, once per convolution specialization.
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, NWGC, GKXC, NWGK,
        F16, F16, F16,
        PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Instance tuple for the default specialization.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_f16_instances<1, NWGC, GKXC, NWGK, ConvBwdWeightDefault>;

    // Instance tuple for the Filter1x1Stride1Pad0 specialization.
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_f16_instances<1,
                                                        NWGC,
                                                        GKXC,
                                                        NWGK,
                                                        ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: default first, then Filter1x1Stride1Pad0.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Passes the DL F32 instance tuples for 1D NWGC/GKXC/NWGK backward-weight convolution to
// add_device_operation_instances, once per convolution specialization.
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<
        1, NWGC, GKXC, NWGK,
        F32, F32, F32,
        PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Instance tuple for the default specialization.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_f32_instances<1, NWGC, GKXC, NWGK, ConvBwdWeightDefault>;

    // Instance tuple for the Filter1x1Stride1Pad0 specialization.
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_f32_instances<1,
                                                        NWGC,
                                                        GKXC,
                                                        NWGK,
                                                        ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: default first, then Filter1x1Stride1Pad0.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
......
add_instance_library(device_grouped_conv2d_bwd_weight_instance set(GROUPED_CONV2D_BWD_WEIGHT
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp)
)
if(DL_KERNELS)
list(APPEND GROUPED_CONV2D_BWD_WEIGHT
device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp)
endif()
add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers the DL grouped 2-D convolution backward-weight instances for
// in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
// (GNHWC / GKYXC / GNHWK layouts; BF16 / F32 / BF16 data types, in the order
// they appear in the interface type below).
//
// Two instance lists are handed to add_device_operation_instances together
// with `instances`, in this order:
//   1. the ConvBwdWeightDefault specialization
//   2. the ConvBwdWeightFilter1x1Stride1Pad0 specialization
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           GNHWC,
                                                           GKYXC,
                                                           GNHWK,
                                                           BF16,
                                                           F32,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // Name each instance list once so the registration calls stay short.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_bf16_instances<2,
                                                         GNHWC,
                                                         GKYXC,
                                                         GNHWK,
                                                         ConvBwdWeightDefault>;
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_bf16_instances<2,
                                                         GNHWC,
                                                         GKYXC,
                                                         GNHWK,
                                                         ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: Default first, Filter1x1Stride1Pad0 second.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers the DL grouped 2-D convolution backward-weight instances for
// in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
// (GNHWC / GKYXC / GNHWK layouts; F16 for all three data types in the
// interface type below).
//
// Two instance lists are handed to add_device_operation_instances together
// with `instances`, in this order:
//   1. the ConvBwdWeightDefault specialization
//   2. the ConvBwdWeightFilter1x1Stride1Pad0 specialization
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           GNHWC,
                                                           GKYXC,
                                                           GNHWK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // Name each instance list once so the registration calls stay short.
    using DefaultInstances =
        device_grouped_conv_bwd_weight_dl_f16_instances<2,
                                                        GNHWC,
                                                        GKYXC,
                                                        GNHWK,
                                                        ConvBwdWeightDefault>;
    using Filter1x1Stride1Pad0Instances =
        device_grouped_conv_bwd_weight_dl_f16_instances<2,
                                                        GNHWC,
                                                        GKYXC,
                                                        GNHWK,
                                                        ConvBwdWeightFilter1x1Stride1Pad0>;

    // Registration order is kept: Default first, Filter1x1Stride1Pad0 second.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment