Unverified Commit 261f1759 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Support large batch tensors in grouped conv bwd data (#1711)

* Support large batch tensors in grouped conv bwd data

* Fix multiD

* fixes

* fixes

* fixes
parent 58e7f37f
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1; static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto transform_conv_to_gemm = using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
TransformConvBwdDataToGemm_v1<NDimSpatial, ConvBackwardDataSpecialization,
ConvBackwardDataSpecialization, K1,
K1, K1,
K1, MPerBlock,
MPerBlock, NPerBlock,
NPerBlock, KPerBlock,
KPerBlock, true /* DoPadGemmM */,
true /* DoPadGemmM */, true /* DoPadGemmN */,
true /* DoPadGemmN */>{}; ALayout,
BLayout,
static auto GetDummyABDsEGridDescriptor() ELayout>;
{
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto ds_grid_desc_m_n =
generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
Number<NumDTensor>{});
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple( return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
} }
// desc // desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>; using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>; using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
...@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, /*ds_g_n_c_wis_lengths*/,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, ds_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
...@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
...@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
tildes = {i_ztilde, i_ytilde, i_xtilde}; tildes = {i_ztilde, i_ytilde, i_xtilde};
} }
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>( conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>( conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
DsGridDesc_M_N ds_grid_desc_m_n; DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc // populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same_v<DLayout, ELayout>);
ds_grid_desc_m_n(i) = ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths, a_g_n_k_wos_lengths,
a_g_n_k_wos_strides, a_g_n_k_wos_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_c_wis_lengths, e_g_n_c_wis_lengths,
e_g_n_c_wis_strides, ds_g_n_c_wis_strides[i],
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
tildes); tildes};
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
});
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// for check validity // for check validity
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
...@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BElementwiseOp b_element_op_; BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_; CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_; std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_; std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_; std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_; std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_; std::array<index_t, NDimSpatial> input_right_pads_;
}; };
......
...@@ -54,15 +54,16 @@ template <typename GridwiseGemm, ...@@ -54,15 +54,16 @@ template <typename GridwiseGemm,
typename ABDataType, typename ABDataType,
typename DsPointer, typename DsPointer,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOp,
typename BElementwiseOperation, typename BElementwiseOp,
typename CDEElementwiseOperation, typename CDEElementwiseOp,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap, typename Block2ETileMap,
typename ComputePtrOffsetOfBatch, typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -73,10 +74,9 @@ __global__ void ...@@ -73,10 +74,9 @@ __global__ void
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOp a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOp b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOp cde_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -84,24 +84,29 @@ __global__ void ...@@ -84,24 +84,29 @@ __global__ void
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map, const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__)) defined(__gfx94__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = amd_wave_read_first_lane( const long_index_t a_batch_offset =
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset = amd_wave_read_first_lane( const long_index_t b_batch_offset =
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset = amd_wave_read_first_lane( const long_index_t e_batch_offset =
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp; DsPointer p_ds_grid_grp;
...@@ -112,10 +117,10 @@ __global__ void ...@@ -112,10 +117,10 @@ __global__ void
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_ds_grid_grp, p_ds_grid_grp,
p_e_grid + e_batch_offset, p_e_grid + e_batch_offset + e_n_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -130,7 +135,6 @@ __global__ void ...@@ -130,7 +135,6 @@ __global__ void
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = p_e_grid; ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
...@@ -139,6 +143,7 @@ __global__ void ...@@ -139,6 +143,7 @@ __global__ void
ignore = b_element_op; ignore = b_element_op;
ignore = cde_element_op; ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch; ignore = compute_ptr_offset_of_batch;
ignore = compute_ptr_offset_of_n;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif #endif
} }
...@@ -233,82 +238,54 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -233,82 +238,54 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto transform_conv_to_gemm = using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
TransformConvBwdDataToGemm_v1<NDimSpatial, ConvBackwardDataSpecialization,
ConvBackwardDataSpecialization, AK1,
AK1, BK1,
BK1, MPerBlock,
MPerBlock, NPerBlock,
NPerBlock, KPerBlock,
KPerBlock, DoPadGemmM,
DoPadGemmM, DoPadGemmN,
DoPadGemmN>{}; ALayout,
BLayout,
static auto GetDummyABDsEGridDescriptor() ELayout,
true, /*SplitConvN*/
ABDataType,
EDataType>;
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{ {
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1}; const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1}; const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto ds_grid_desc_m_n = generate_tuple( const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>( using ConvToGemmBwdDataTransformD =
dummy_tensor_lengths, TransformConvBwdDataToGemm_v1<NDimSpatial,
dummy_tensor_strides, ConvBackwardDataSpecialization,
dummy_tensor_lengths, AK1,
dummy_tensor_strides, BK1,
dummy_tensor_lengths, MPerBlock,
dummy_tensor_strides, NPerBlock,
dummy_spatial_lengths, KPerBlock,
dummy_spatial_lengths, DoPadGemmM,
dummy_spatial_lengths, DoPadGemmN,
dummy_spatial_lengths, ALayout,
dummy_spatial_lengths); BLayout,
DLayout,
true, /*SplitConvN*/
ABDataType,
DDataType>;
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
const auto e_grid_desc_m_n = const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
return make_tuple( return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
...@@ -377,7 +354,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -377,7 +354,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
} }
// desc // desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>; using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>; using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
...@@ -431,15 +409,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -431,15 +409,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
...@@ -450,11 +421,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -450,11 +421,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
}); });
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
}); });
...@@ -526,68 +492,65 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -526,68 +492,65 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
throw std::runtime_error("wrong! only implemented for 2D and 3D now"); throw std::runtime_error("wrong! only implemented for 2D and 3D now");
} }
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
conv_N_per_block_ = conv_to_gemm_transform_.N_;
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>( conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>( conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
DsGridDesc_M_N ds_grid_desc_m_n; DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc // populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
ds_grid_desc_m_n(i) = using ConvToGemmBwdDataTransformD =
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>( TransformConvBwdDataToGemm_v1<NDimSpatial,
a_g_n_k_wos_lengths, ConvBackwardDataSpecialization,
a_g_n_k_wos_strides, AK1,
b_g_k_c_xs_lengths, BK1,
b_g_k_c_xs_strides, MPerBlock,
ds_g_n_c_wis_lengths[i], NPerBlock,
ds_g_n_c_wis_strides[i], KPerBlock,
conv_filter_strides, DoPadGemmM,
conv_filter_dilations, DoPadGemmN,
input_left_pads, ALayout,
input_right_pads, BLayout,
tildes); DLayout,
}); true, /*SplitConvN*/
ABDataType,
const auto e_grid_desc_m_n = DDataType>;
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>( ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
a_g_n_k_wos_lengths, a_g_n_k_wos_lengths,
a_g_n_k_wos_strides, a_g_n_k_wos_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_c_wis_lengths, ds_g_n_c_wis_lengths[i],
e_g_n_c_wis_strides, ds_g_n_c_wis_strides[i],
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
tildes); tildes};
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
});
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// desc for problem definition // desc for problem definition
const auto a_grid_desc_m_k = const auto a_grid_desc_m_k =
...@@ -628,6 +591,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -628,6 +591,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
} }
} }
} }
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
} }
void Print() const void Print() const
...@@ -660,6 +630,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -660,6 +630,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// tensor descriptor for problem definition // tensor descriptor for problem definition
index_t num_group_; index_t num_group_;
index_t conv_N_per_block_;
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_; std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_; std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_; std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
...@@ -678,23 +649,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -678,23 +649,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op // element-wise op
AElementwiseOp a_element_op_; AElementwiseOp a_element_op_;
BElementwiseOp b_element_op_; BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_; CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_; std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_; std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_; std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_; std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_; std::array<index_t, NDimSpatial> input_right_pads_;
}; };
...@@ -711,8 +675,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -711,8 +675,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.Print(); arg.Print();
} }
float ave_time = 0; const index_t gdy = arg.num_group_;
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdz = num_workgroups_per_Conv_N;
float ave_time = 0;
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
...@@ -724,9 +692,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -724,9 +692,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
throw std::runtime_error("wrong! device_op has invalid setting"); throw std::runtime_error("wrong! device_op has invalid setting");
} }
const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize( const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
arg.e_grid_desc_m_n_container_[i]) * arg.e_grid_desc_m_n_container_[i]);
arg.num_group_;
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
...@@ -747,12 +714,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -747,12 +714,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap, Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>, ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
kernel, kernel,
dim3(grid_size), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
...@@ -762,13 +730,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -762,13 +730,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.a_g_n_k_wos_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_container_[i], arg.a_grid_desc_ak0_m_ak1_container_[i],
arg.b_grid_desc_bk0_n_bk1_container_[i], arg.b_grid_desc_bk0_n_bk1_container_[i],
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.block_2_etile_map_container_[i], arg.block_2_etile_map_container_[i],
arg.compute_ptr_offset_of_batch_); arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_);
}; };
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK))
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -13,245 +13,614 @@ ...@@ -13,245 +13,614 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace {
template < template <
index_t NDimSpatial, index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
index_t AK1,
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
bool DoPadGemmN,
typename ALayout, typename ALayout,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization> typename BLayout,
constexpr auto make_out_grid_desc(const index_t N, typename CLayout,
const index_t Do, bool SplitN = false,
const index_t Ho, typename ADataType = float,
const index_t Wo, typename CDataType = float,
const index_t K, index_t NumGroupsToMerge = 1,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides) typename IndexType = index_t>
struct TransformConvBwdDataToGemm_v1
{ {
const auto KStride = Number<1>{}; private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>) static constexpr auto NonSpatialDimsNum = Number<3>{};
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t HiStride = out_g_n_k_wos_strides[3];
const index_t WiStride = out_g_n_k_wos_strides[4];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), static constexpr auto DIdx = NonSpatialDimsNum;
make_tuple(WiStride, KStride)); static constexpr auto HIdx =
} NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
else static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = NonSpatialDimsNum;
static constexpr auto YIdx =
NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
{ {
return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K), acc +=
make_tuple(NStride, HiStride, WiStride, KStride)); static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
} }
return acc;
} }
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
template <typename ConvDimsType>
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides)
{ {
const index_t NStride = out_g_n_k_wos_strides[1]; const long_index_t a_element_space_size =
const index_t DoStride = out_g_n_k_wos_strides[3]; calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1);
const index_t HoStride = out_g_n_k_wos_strides[4]; const long_index_t c_element_space_size =
const index_t WoStride = out_g_n_k_wos_strides[5]; calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1);
if constexpr(ConvBwdDataSpecialization == const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: c_element_space_size * sizeof(CDataType));
Filter1x1Stride1Pad0) constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const IndexType N = a_g_n_k_wos_lengths[I1];
if(element_space_size > TwoGB)
{ {
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), if(divisor <= static_cast<double>(N))
make_tuple(WoStride, KStride)); {
// Find least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N). There are no divisors above this value.
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
{
return N / least_divisor;
}
}
// Not found, process one Convolution N per block
return 1;
}
else
{
// Not possible to support even after split N.
// Too large tensor.
return N;
}
} }
else else
{ {
return make_naive_tensor_descriptor( // Split N is not needed.
make_tuple(N, Do, Ho, Wo, K), return N;
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
} }
} }
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
public:
__host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {}
template <typename TransformConvBwdDataToGemm_v1Base>
__host__ __device__ TransformConvBwdDataToGemm_v1(
const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base)
: N_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.N_)},
Di_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wi_)},
Do_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Do_)},
Ho_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Ho_)},
Wo_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wo_)},
Z_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Z_)},
Y_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Y_)},
X_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.X_)},
K_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.K_)},
C_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.C_)},
DiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DiStride_)},
HiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HiStride_)},
WiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WiStride_)},
DoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DoStride_)},
HoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HoStride_)},
WoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WoStride_)},
CStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)},
CStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)},
KStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)},
KStrideTensorB_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)},
NStrideTensorA_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)},
NStrideTensorC_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)},
ConvStrideD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)},
ConvStrideH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)},
ConvStrideW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)},
ConvDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)},
ConvDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)},
ConvDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)},
InLeftPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)},
InLeftPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)},
InLeftPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadW_)},
IdxZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)},
IdxYTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)},
IdxXTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)},
GcdStrideDilationD_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)},
GcdStrideDilationH_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)},
GcdStrideDilationW_{
static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)},
ZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZTilde_)},
YTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YTilde_)},
XTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XTilde_)},
DTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DTilde_)},
HTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HTilde_)},
WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
{ {
// assume packed }
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: template <typename ConvDimsType, typename ConvSpatialDimsType>
Filter1x1Stride1Pad0) __host__ __device__
TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_k_wos_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_c_wis_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads,
const ConvSpatialDimsType& tildes)
: Hi_{c_g_n_c_wis_lengths[HIdx]},
Wi_{c_g_n_c_wis_lengths[WIdx]},
Ho_{a_g_n_k_wos_lengths[HIdx]},
Wo_{a_g_n_k_wos_lengths[WIdx]},
Y_{b_g_k_c_xs_lengths[YIdx]},
X_{b_g_k_c_xs_lengths[XIdx]},
K_{a_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
HiStride_{c_g_n_c_wis_strides[HIdx]},
WiStride_{c_g_n_c_wis_strides[WIdx]},
HoStride_{a_g_n_k_wos_strides[HIdx]},
WoStride_{a_g_n_k_wos_strides[WIdx]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
CStrideTensorC_{c_g_n_c_wis_strides[I2]},
KStrideTensorA_{a_g_n_k_wos_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
NStrideTensorA_{a_g_n_k_wos_strides[I1]},
NStrideTensorC_{c_g_n_c_wis_strides[I1]},
ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]},
ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]},
ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]},
ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]},
InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]},
InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]},
InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); N_ = GetSplitedNSize(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides);
} }
else else
{ {
return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); N_ = c_g_n_c_wis_lengths[I1];
} }
} if constexpr(NDimSpatial == 3)
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); Di_ = c_g_n_c_wis_lengths[DIdx];
Do_ = a_g_n_k_wos_lengths[DIdx];
Z_ = b_g_k_c_xs_lengths[ZIdx];
DiStride_ = c_g_n_c_wis_strides[DIdx];
DoStride_ = a_g_n_k_wos_strides[DIdx];
ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum];
ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum];
InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum];
InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum];
IdxZTilde_ = tildes[ZIdx - NonSpatialDimsNum];
GcdStrideDilationD_ = math::gcd(ConvStrideD_, ConvDilationD_);
ZTilde_ = ConvStrideD_ / GcdStrideDilationD_;
DTilde_ = Do_ + math::integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_);
ZDot_ = math::integer_divide_ceil(Z_, ZTilde_);
} }
else else
{ {
return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1;
InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0;
} }
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
template <typename BLayout> GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_);
constexpr auto make_wei_grid_desc( GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_);
const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C)
{
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>) YTilde_ = ConvStrideH_ / GcdStrideDilationH_;
{ XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
template <index_t NDimSpatial, typename CLayout>
constexpr auto make_in_grid_desc(const index_t N,
const index_t Di,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides)
{
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> || HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_);
is_same_v<CLayout, tensor_layout::convolution::NHWGC> || WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_);
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{ YDot_ = math::integer_divide_ceil(Y_, YTilde_);
return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), XDot_ = math::integer_divide_ceil(X_, XTilde_);
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[2]));
} }
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>) #if 0 // At now not supported to split tensor
__host__ bool AreDescriptorsSmallerThan2GB() const
{ {
return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C), constexpr long_index_t TwoGB = (long_index_t{1} << 31);
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3], const long_index_t in_desc_space_size =
in_g_n_c_wis_strides[4], I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
in_g_n_c_wis_strides[5], (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
in_g_n_c_wis_strides[2])); const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_;
bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB;
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
} }
else
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{ {
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name()); // Create copies
} auto conv_to_gemm_transformer_left = *this;
} auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate real filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow to split if whole left padding will be in left tensor and right padding in right
// tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
;
// Calcualte offsets
a_right_offset = (Do_ / 2) * DoStride_;
c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
}
else if(is_possible_to_split_h)
{
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
} // namespace conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
template < conv_to_gemm_transformer_left.InRightPadH_ = 0;
index_t NDimSpatial, conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
index_t AK1,
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto NonSpatialDimsNum = Number<3>{}; conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
conv_to_gemm_transformer_right.Hi_ =
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
a_right_offset = (Ho_ / 2) * HoStride_;
c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
}
else if(is_possible_to_split_w)
{
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
static constexpr auto DIdx = Number<NonSpatialDimsNum>{}; conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
static constexpr auto HIdx = conv_to_gemm_transformer_right.InLeftPadW_ = 0;
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{}; conv_to_gemm_transformer_left.InRightPadW_ = 0;
static constexpr auto YIdx = conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
template <typename ALayout, conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && conv_to_gemm_transformer_right.Wi_ =
(is_same_v<ALayout, tensor_layout::convolution::GNHWK> || math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> || (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>), a_right_offset = (Wo_ / 2) * WoStride_;
bool>::type = false> c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
static auto MakeADescriptor_AK0_M_AK1( }
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths, // Return left transform, right transformer, right offset to Input and right offset to
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides, // Output
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths, return ck::make_tuple(conv_to_gemm_transformer_left,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */, conv_to_gemm_transformer_right,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths, a_grid_ptr_base + a_right_offset,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */, c_grid_ptr_base + c_right_offset);
const std::array<index_t, NDimSpatial>& conv_filter_strides, }
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
const std::array<index_t, NDimSpatial>& /* input_right_pads */, CDataType* c_grid_ptr_base) const
const std::array<index_t, NDimSpatial>& tildes)
{ {
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; // Create copies
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; auto conv_to_gemm_transformer_left = *this;
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate start position in input for right tensor
const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_);
const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_);
const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_);
// Calculate last position in input for left tensor
const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_);
const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_);
const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_);
if(Di_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Di_ = Di_ / 2;
conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// // Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx;
conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx;
;
// Calcualte offsets
a_right_offset = do_right_transformer_start_idx * DoStride_;
c_right_offset = (Di_ / 2) * DiStride_;
}
else if(Hi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Hi_ = Hi_ / 2;
conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
// // Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
// Calculate new input size
conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx ;
conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx ;
;
// Calcualte offsets
a_right_offset = ho_right_transformer_start_idx * HoStride_;
c_right_offset = (Hi_ / 2) * HiStride_;
}
else if(Wi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Wi_ = Wi_ / 2;
conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
// Calculate new input size
conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx;
conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx;
;
// Calcualte offsets
a_right_offset = wo_right_transformer_start_idx * WoStride_;
c_right_offset = (Wi_ / 2) * WiStride_;
}
// Return left transform, right transformer, right offset to Input and right offset to
// Output
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
#endif
const index_t N = in_g_n_c_wis_lengths[1]; __host__ __device__ auto MakeOutGridDesc() const
const index_t K = wei_g_k_c_xs_lengths[1]; {
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1; return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
const index_t Hi = in_g_n_c_wis_lengths[HIdx]; make_tuple(WoStride_, KStrideTensorA_));
const index_t Wi = in_g_n_c_wis_lengths[WIdx]; }
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
const index_t Ho = out_g_n_k_wos_lengths[HIdx]; make_tuple(WoStride_, KStrideTensorA_));
const index_t Wo = out_g_n_k_wos_lengths[WIdx]; }
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, K_),
make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N_ * Do_ * Ho_ * Wo_, K_));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; __host__ __device__ auto MakeWeiGridDesc() const
const index_t Y = wei_g_k_c_xs_lengths[YIdx]; {
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum]; if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum]; {
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum]; return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; __host__ __device__ auto MakeInGridDesc() const
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; {
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_));
}
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
template <
typename ALayout_ = ALayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<ALayout_, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout_, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout_, tensor_layout::convolution::NDHWGK>),
bool>::type = false>
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
{
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc = const auto out_grid_desc = MakeOutGridDesc();
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
N, Do, Ho, Wo, K, out_g_n_k_wos_strides);
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const index_t AK0 = math::integer_divide_ceil(K, AK1); const index_t AK0 = math::integer_divide_ceil(K_, AK1);
// A: output tensor // A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc, out_grid_desc,
make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0, AK1))), make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{})); make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
...@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1
} }
else else
{ {
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor // only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor( const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min( const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min( const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min( const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM // GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
{ {
// A: output tensor // A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc, out_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo, I0, I0), make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N_),
make_embed_transform(make_tuple(YDot, HTilde), make_embed_transform(make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot, WTilde), make_embed_transform(make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1
// A: output tensor // A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc, out_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Do, I0, I0), make_pad_transform(Do_, I0, I0),
make_pad_transform(Ho, I0, I0), make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo, I0, I0), make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
...@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc, out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform( make_embed_transform(
make_tuple(ZDot, DTilde), make_tuple(ZDot_, DTilde_),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)),
make_embed_transform( make_embed_transform(
make_tuple(YDot, HTilde), make_tuple(YDot_, HTilde_),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform( make_embed_transform(
make_tuple(XDot, WTilde), make_tuple(XDot_, WTilde_),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(
make_slice_transform(ZDot, I0, ZDotSlice), make_pass_through_transform(N_),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_pass_through_transform(K)), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))), make_merge_transform(
make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1
} }
} }
template <typename BLayout, template <typename BLayout_ = BLayout,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<BLayout, tensor_layout::convolution::GKYXC> || (is_same_v<BLayout_, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>), is_same_v<BLayout_, tensor_layout::convolution::GKZYXC>),
bool>::type = false> bool>::type = false>
static auto MakeBDescriptor_BK0_N_BK1( __host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& /* input_left_pads */,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
{ {
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume packed // assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d // k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C); const auto wei_grid_desc = MakeWeiGridDesc();
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const index_t BK0 = math::integer_divide_ceil(K, BK1); const index_t BK0 = math::integer_divide_ceil(K_, BK1);
// B: weight tensor // B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor( ck::tensor_operation::device::PadTensorDescriptor(
...@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1
} }
else else
{ {
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
// GemmK is different for each GEMM // GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
// B weight tensor // B weight tensor
if constexpr(NDimSpatial == 2) if constexpr(NDimSpatial == 2)
...@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_grid_desc, wei_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(K), make_pass_through_transform(K_),
make_embed_transform(make_tuple(YDot, YTilde), make_embed_transform(make_tuple(YDot_, YTilde_),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_embed_transform(make_tuple(XDot, XTilde), make_embed_transform(make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K_),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
...@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_grid_desc, wei_k_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
wei_grid_desc, wei_grid_desc,
make_tuple( make_tuple(make_pass_through_transform(K_),
make_pass_through_transform(K), make_embed_transform(
make_embed_transform(make_tuple(ZDot, ZTilde), make_tuple(ZDot_, ZTilde_),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)), make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)),
make_embed_transform(make_tuple(YDot, YTilde), make_embed_transform(
make_tuple(ConvStrideH / GcdStrideDilationH, I1)), make_tuple(YDot_, YTilde_),
make_embed_transform(make_tuple(XDot, XTilde), make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)), make_embed_transform(
make_pass_through_transform(C)), make_tuple(XDot_, XTilde_),
make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_tuple(make_pass_through_transform(K_),
make_slice_transform(ZDot, I0, ZDotSlice), make_slice_transform(ZDot_, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(YDot_, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice), make_slice_transform(XDot_, I0, XDotSlice),
make_freeze_transform(i_ztilde), make_freeze_transform(IdxZTilde_),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
...@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), make_tuple(
make_pass_through_transform(C)), make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1
} }
} }
template <typename CLayout, template <
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && typename CLayout_ = CLayout,
(is_same_v<CLayout, tensor_layout::convolution::GNHWC> || typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
is_same_v<CLayout, tensor_layout::convolution::GNDHWC> || (is_same_v<CLayout_, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> || is_same_v<CLayout_, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC> || is_same_v<CLayout_, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>), is_same_v<CLayout_, tensor_layout::convolution::NDHWGC> ||
bool>::type = false> is_same_v<CLayout_, tensor_layout::convolution::G_NHW_C>),
static auto bool>::type = false>
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths, __host__ __device__ auto MakeCDescriptor_M_N() const
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const std::array<index_t, NDimSpatial>& tildes)
{ {
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum];
const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum];
const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume strided // assume strided
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d // n_hi_wi_c for 2d n_di_hi_wi_c for 3d
const auto in_grid_desc = const auto in_grid_desc = MakeInGridDesc();
make_in_grid_desc<NDimSpatial, CLayout>(N, Di, Hi, Wi, C, in_g_n_c_wis_strides);
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
...@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
...@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1
in_n_y_ho_x_wo_c_grid_desc, in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0), make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)), make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
...@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple( make_tuple(
make_pass_through_transform(N), make_pass_through_transform(N_),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(make_freeze_transform(I0), make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_freeze_transform(I0), make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, make_tuple(Sequence<1>{},
Sequence<3>{}, Sequence<3>{},
Sequence<5>{}, Sequence<5>{},
...@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1
} }
else else
{ {
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on DTilde, HTilde and WTilde that contribute to // only work on DTilde, HTilde and WTilde that contribute to
// non-padding area of input tensor // non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor( const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min( const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min( const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min( const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
...@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1
{ {
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc, in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(YTilde, HTilde), make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde, WTilde), make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc, in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1
{ {
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc, in_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
...@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc, in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(ZTilde, DTilde), make_embed_transform(make_tuple(ZTilde_, DTilde_),
make_tuple(ConvDilationD, ConvStrideD)), make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(YTilde, HTilde), make_embed_transform(make_tuple(YTilde_, HTilde_),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(XTilde, WTilde), make_embed_transform(make_tuple(XTilde_, WTilde_),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor( transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_freeze_transform(i_ztilde), make_freeze_transform(IdxZTilde_),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde), make_freeze_transform(IdxYTilde_),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde), make_freeze_transform(IdxXTilde_),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple( make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1
} }
// for input bias // for input bias
template <typename CLayout, template <typename CLayout_ = CLayout,
typename std::enable_if<NDimSpatial == 2 && typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::GC> || (is_same_v<CLayout_, tensor_layout::convolution::GC> ||
is_same_v<CLayout, tensor_layout::convolution::G_C>), is_same_v<CLayout_, tensor_layout::convolution::G_C>),
bool>::type = false> bool>::type = false>
static auto __host__ __device__ auto MakeCDescriptor_M_N() const
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& /* tildes */)
{ {
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
if constexpr(ConvBwdDataSpecialization == if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0) Filter1x1Stride1Pad0)
{ {
const auto in_gemmm_gemmn_grid_desc = const auto in_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
return in_gemmm_gemmn_grid_desc; return in_gemmm_gemmn_grid_desc;
} }
else else
{ {
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor // only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor( const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor( const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IHTildeSliceEnd = math::min( const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min( const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// bias tensor // bias tensor
const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor( const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1)); make_tuple(N_ * HTildeSlice * WTildeSlice, C_), make_tuple(I0, I1));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc, in_gemmmraw_gemmnraw_grid_desc,
...@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1 ...@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1
return in_gemmm_gemmn_grid_desc; return in_gemmm_gemmn_grid_desc;
} }
} }
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;
IndexType K_, C_;
IndexType DiStride_, HiStride_, WiStride_;
IndexType DoStride_, HoStride_, WoStride_;
IndexType CStrideTensorB_, CStrideTensorC_, KStrideTensorA_, KStrideTensorB_;
IndexType NStrideTensorA_, NStrideTensorC_;
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_;
IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_;
IndexType ZTilde_, YTilde_, XTilde_;
IndexType DTilde_, HTilde_, WTilde_;
IndexType ZDot_, YDot_, XDot_;
}; };
} // namespace tensor_operation } // namespace tensor_operation
......
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_data_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif() endif()
add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
template <typename Tuple>
class TestGroupedConvndBwdDataWmma : public ::testing::Test
{
protected:
using DataType = std::tuple_element_t<0, Tuple>;
using OutLayout = std::tuple_element_t<1, Tuple>;
using WeiLayout = std::tuple_element_t<2, Tuple>;
using InLayout = std::tuple_element_t<3, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params;
template <ck::index_t NDimSpatial>
void Run()
{
EXPECT_FALSE(conv_params.empty());
bool pass = true;
for(auto& param : conv_params)
{
pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
OutLayout,
WeiLayout,
InLayout,
DataType,
DataType,
DataType>(
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
param);
}
EXPECT_TRUE(pass);
}
};
using namespace ck::tensor_layout::convolution;
using KernelTypes2d = ::testing::Types<std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
using KernelTypes3d = ::testing::Types<std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
template <typename Tuple>
class TestGroupedConvndBwdDataWmma2d : public TestGroupedConvndBwdDataWmma<Tuple>
{
};
template <typename Tuple>
class TestGroupedConvndBwdDataWmma3d : public TestGroupedConvndBwdDataWmma<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdDataWmma2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndBwdDataWmma3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include <iostream> #include <iostream>
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" #include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdData : public ::testing::Test class TestGroupedConvndBwdDataXdl : public ::testing::Test
{ {
protected: protected:
using DataType = std::tuple_element_t<0, Tuple>; using DataType = std::tuple_element_t<0, Tuple>;
...@@ -51,35 +51,31 @@ using namespace ck::tensor_layout::convolution; ...@@ -51,35 +51,31 @@ using namespace ck::tensor_layout::convolution;
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>, using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>, std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>, std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
std::tuple<float, NHWGK, GKYXC, NHWGC>, std::tuple<float, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>, std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>, std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>, using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>, std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>, std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>, std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>, std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>, std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple> class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl<Tuple>
{ {
}; };
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData<Tuple> class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl<Tuple>
{ {
}; };
TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
{ {
this->conv_params.clear(); this->conv_params.clear();
...@@ -94,10 +90,13 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) ...@@ -94,10 +90,13 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
// SplitN case
this->conv_params.push_back(
{2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->template Run<2>();
} }
TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D)
{ {
this->conv_params.clear(); this->conv_params.clear();
this->conv_params.push_back( this->conv_params.push_back(
...@@ -112,5 +111,17 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) ...@@ -112,5 +111,17 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back( this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// SplitN case
this->conv_params.push_back({3,
1,
128,
4,
192,
{2, 2, 2},
{2, 224, 224},
{1, 224, 224},
{1, 1, 1},
{0, 0, 0},
{0, 0, 0}});
this->template Run<3>(); this->template Run<3>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment