Unverified Commit 4ec5c52a authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Add Grouped Conv Fwd Large Tensor kernel (#1432)

* Support 64 bit indexing

* Add new grouped conv fwd kernel for large tensors

* Add instances large tensor

* Fixes for transform conv to gemm

* Fixes

* fixes

* Remove not needed instances

* examples fixes

* Remove not need ds arrays

* Fix tests

* Add 2GB check in gridwise dl

* Fixes
parent 7f57b2e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cassert>
......@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
inline HostTensorDescriptor
make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
{
std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_};
std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};
ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
......@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
// reset input to zero
in_device_buf.SetZero();
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
for(ck::index_t d = 0; d < NDimSpatial; d++)
{
input_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
filter_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
output_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
conv_filter_dilations_i32[d] =
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
}
// do GEMM
auto conv = DeviceConvNdBwdDataInstance{};
auto invoker = conv.MakeInvoker();
......@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.GetOutputSpatialLengths(),
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
static_cast<ck::index_t>(conv_param.N_),
static_cast<ck::index_t>(conv_param.K_),
static_cast<ck::index_t>(conv_param.C_),
input_spatial_lengths_i32,
filter_spatial_lengths_i32,
output_spatial_lengths_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
in_element_op,
wei_element_op,
out_element_op);
......
......@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_a,
BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
......@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
......@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
b_element_op,
cde_element_op};
}
static __device__ __host__ auto MakeArgument(
APointers p_as,
BPointers p_bs,
const ck::Array<const void*, NumDTensor>& p_ds,
void* p_e,
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<long_index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<long_index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<long_index_t, NDimSpatial>& input_left_pads,
const ck::Array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
};
} // namespace device
......
......@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
static constexpr auto spatial_offset = Number<3>{};
using GemmToConvFwdTransformer =
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
......@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
: independent_filter_stride;
}
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
......
......@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
......@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
......@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
template <typename CLay>
static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
......@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
......@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
......
......@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ALayout,
......@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
......@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
......@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op};
}
static auto
MakeArgument(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
APointers p_a,
BPointers p_b,
APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
......@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths,
......@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
......@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
......@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
......@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
......@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
......@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false;
}
// Gridwise gemm v3 doesn't verify descriptors size
if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB())
{
return false;
}
// check Gridwise GEMM
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
......@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op};
}
static auto
MakeArgument(const void* p_as,
const void* p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
......@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
......
......@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds =
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
using BGridDesc =
......@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
......@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <queue>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
template <typename GridwiseGemm,
index_t MaxGemmsNum,
typename GemmArgs,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputePtrOffset,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle(
Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args,
const index_t gemms_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op,
const ComputePtrOffset compute_ptr_offset_of_groups,
const ComputePtrOffset compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t e_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
index_t left = 0;
index_t right = gemms_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ &&
block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) &&
left <= right)
{
if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset,
gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset,
Tuple<>{},
gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
Tuple<>{},
gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_kernel_args[group_id].block_2_etile_map_);
#else
ignore = gemm_desc_kernel_args;
ignore = gemms_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_ptr_offset_of_groups;
ignore = compute_ptr_offset_of_n;
#endif
}
} // namespace
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
AComputeDataType,
BComputeDataType>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t MaxGemmsNum = 32;
static_assert(NumDTensor == 0, "MultiD not supported.");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ConvToGemmFwdTransformerIndexT = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
I1,
index_t>;
using ConvToGemmFwdTransformerLongIndexT = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
I1,
long_index_t>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
// desc for problem definition
constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
static auto
GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base,
const ADataType* a_grid_ptr_base,
EDataType* c_grid_ptr_base)
{
// Max number of splits
// We need to use it to avoid infinity loop
constexpr index_t max_split_numbers = MaxGemmsNum / 2;
// Arrays to store transformers with smaller descs than 2GB
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformers_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs_arr;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs_arr;
// Queue for spliting
std::queue<ConvToGemmFwdTransformerLongIndexT> conv_to_gemm_transformers_queue(
{conv_to_gemm_transformer_base});
std::queue<const ADataType*> a_grid_ptrs_queue({a_grid_ptr_base});
std::queue<EDataType*> c_grid_ptrs_queue({c_grid_ptr_base});
index_t gemms_number = 0;
index_t split_numbers = 0;
// Algorithm:
// While queue is not empty:
// 1. Get transformer from queue.
// 2. If descs are smaller than 2GB push to result array.
// 3. If descs are bigger than 2GB split into left and right transformer.
// and push the both into the queue.
while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers &&
gemms_number < MaxGemmsNum)
{
// Get transformer from the queue
const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front();
const ADataType* a_grid_ptr = a_grid_ptrs_queue.front();
EDataType* c_grid_ptr = c_grid_ptrs_queue.front();
// Check if convolution not exceed 2GB
if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB())
{
// If yes, push into result array
conv_to_gemm_transformers_arr(gemms_number) =
ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer};
a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
gemms_number++;
}
else
{
// If no, split into left and right convolutions
ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part;
const ADataType* a_grid_right_ptr;
EDataType* c_grid_right_ptr;
ck::tie(conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part,
a_grid_right_ptr,
c_grid_right_ptr) =
conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr);
conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part);
conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part);
// Left offsets remain the same
a_grid_ptrs_queue.push(a_grid_ptr);
a_grid_ptrs_queue.push(a_grid_right_ptr);
c_grid_ptrs_queue.push(c_grid_ptr);
c_grid_ptrs_queue.push(c_grid_right_ptr);
split_numbers++;
}
// Remove from the queue
conv_to_gemm_transformers_queue.pop();
a_grid_ptrs_queue.pop();
c_grid_ptrs_queue.pop();
}
const bool is_split_valid = conv_to_gemm_transformers_queue.empty();
return ck::make_tuple(conv_to_gemm_transformers_arr,
a_grid_ptrs_arr,
c_grid_ptrs_arr,
gemms_number,
is_split_valid);
}
#define GridwiseGemmTemplateParameters \
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
AComputeDataType
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Structure for each gemm(conv)
struct GemmArgs
{
// pointers
const ADataType* a_ptr_;
const BDataType* b_ptr_;
EDataType* e_ptr_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
};
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& /*p_ds*/,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_lengths*/,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_strides*/,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
: num_group_{static_cast<index_t>(a_g_n_c_wis_lengths[0])},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// Perform grouped gemm, generate array of tranformer for convolution
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
ck::tie(conv_to_gemm_transformer_arr,
a_grid_ptrs,
c_grid_ptrs,
gemms_count_,
is_split_valid_) =
GenerateConvToGemmTransforms(
ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
static_cast<const ADataType*>(p_a),
static_cast<EDataType*>(p_e));
grid_size_ = 0;
valid_gemms_count_ = 0;
if(is_split_valid_)
{
// Create GemmArg for each gemm(conv)
for(index_t i = 0; i < gemms_count_; i++)
{
const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
conv_to_gemm_transformer_arr[i])};
const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
conv_to_gemm_transformer_arr[i])};
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);
const auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
const index_t grid_size_grp =
block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
Tuple<>{},
e_grid_desc_m_n,
block_2_etile_map))
{
gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
a_grid_ptrs[i],
static_cast<const BDataType*>(p_b),
c_grid_ptrs[i],
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n),
block_2_etile_map,
BlockStart,
BlockEnd};
valid_gemms_count_++;
}
}
// N is the same for all convs
conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
}
// Strides for G and N remain the same
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
}
void Print() const
{
for(index_t i = 0; i < valid_gemms_count_; i++)
{
std::cout << "A[AK0, M, AK1]: " << gemm_desc_kernel_args_[i].a_grid_desc_ak0_m_ak1_
<< std::endl;
std::cout << "B[BK0, N, BK1]: " << gemm_desc_kernel_args_[i].b_grid_desc_bk0_n_bk1_
<< std::endl;
std::cout
<< "E[MBlock, MPerBlock, NBlock, NPerBlock]: "
<< gemm_desc_kernel_args_[i].e_grid_desc_mblock_mperblock_nblock_nperblock_
<< std::endl;
}
}
index_t num_group_;
index_t conv_N_per_block_;
Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args_;
index_t grid_size_;
index_t gemms_count_;
index_t valid_gemms_count_;
bool is_split_valid_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<long_index_t, NDimSpatial> conv_filter_strides_;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_;
std::array<long_index_t, NDimSpatial> input_left_pads_;
std::array<long_index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.grid_size_;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;
// K is constant for all gemms
const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
GridwiseGemm,
MaxGemmsNum,
GemmArgs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.gemm_desc_kernel_args_,
arg.gemms_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];
// Check if all descs are valid
if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_))
{
return false;
}
// check device
if(get_device_name() == "gfx908")
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}
}
if(!ck::is_xdl_supported())
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
{
return false;
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of E
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
std::array<long_index_t, NDimSpatial> input_left_pads_i64;
std::array<long_index_t, NDimSpatial> input_right_pads_i64;
array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i64, conv_filter_strides);
array_convert(conv_filter_dilations_i64, conv_filter_dilations);
array_convert(input_left_pads_i64, input_left_pads);
array_convert(input_right_pads_i64, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i64,
a_g_n_c_wis_strides_i64,
b_g_k_c_xs_lengths_i64,
b_g_k_c_xs_strides_i64,
ds_g_n_k_wos_lengths_i64,
ds_g_n_k_wos_strides_i64,
e_g_n_k_wos_lengths_i64,
e_g_n_k_wos_strides_i64,
conv_filter_strides_i64,
conv_filter_dilations_i64,
input_left_pads_i64,
input_right_pads_i64,
a_element_op,
b_element_op,
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
std::array<long_index_t, NDimSpatial> input_left_pads_i64;
std::array<long_index_t, NDimSpatial> input_right_pads_i64;
array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i64, conv_filter_strides);
array_convert(conv_filter_dilations_i64, conv_filter_dilations);
array_convert(input_left_pads_i64, input_left_pads);
array_convert(input_right_pads_i64, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i64,
a_g_n_c_wis_strides_i64,
b_g_k_c_xs_lengths_i64,
b_g_k_c_xs_strides_i64,
ds_g_n_k_wos_lengths_i64,
ds_g_n_k_wos_strides_i64,
e_g_n_k_wos_lengths_i64,
e_g_n_k_wos_strides_i64,
conv_filter_strides_i64,
conv_filter_dilations_i64,
input_left_pads_i64,
input_right_pads_i64,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using GemmToConvFwdTransformer =
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
......@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N;
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
{
return false;
}
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
{
return false;
}
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
......
......@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1>
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvFwdToGemm
{
private:
......@@ -46,10 +47,10 @@ struct TransformConvFwdToGemm
}
template <typename ConvDimsType>
static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
{
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
......@@ -59,7 +60,7 @@ struct TransformConvFwdToGemm
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const index_t N = a_g_n_c_wis_lengths[I1];
const IndexType N = a_g_n_c_wis_lengths[I1];
if(element_space_size > TwoGB)
{
......@@ -70,7 +71,7 @@ struct TransformConvFwdToGemm
{
// Find least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N). There are no divisors above this value.
for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
......@@ -98,6 +99,53 @@ struct TransformConvFwdToGemm
public:
__host__ __device__ constexpr TransformConvFwdToGemm() {}
template <typename TransformConvFwdToGemmBase>
__host__ __device__
TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
: N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
DiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DiStride_)},
HiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HiStride_)},
WiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WiStride_)},
DoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DoStride_)},
HoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HoStride_)},
WoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WoStride_)},
XStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.XStride_)},
CStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorA_)},
CStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorB_)},
KStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorB_)},
KStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorC_)},
NStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorA_)},
NStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorC_)},
GStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorA_)},
GStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorB_)},
GStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorC_)},
ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
{
}
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
......@@ -126,6 +174,8 @@ struct TransformConvFwdToGemm
DiStride_{I1},
HiStride_{I1},
WiStride_{a_g_n_c_wis_strides[I3]},
DoStride_{I1},
HoStride_{I1},
WoStride_{c_g_n_k_wos_strides[I3]},
XStride_{b_g_k_c_xs_strides[I3]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
......@@ -133,6 +183,7 @@ struct TransformConvFwdToGemm
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
......@@ -150,10 +201,10 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I0]},
ZYX_{X_}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
......@@ -164,7 +215,6 @@ struct TransformConvFwdToGemm
{
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Wo_;
}
template <typename ConvDimsType,
......@@ -195,6 +245,8 @@ struct TransformConvFwdToGemm
DiStride_{I1},
HiStride_{a_g_n_c_wis_strides[I3]},
WiStride_{a_g_n_c_wis_strides[I4]},
DoStride_{I1},
HoStride_{c_g_n_k_wos_strides[I3]},
WoStride_{c_g_n_k_wos_strides[I4]},
XStride_{b_g_k_c_xs_strides[I4]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
......@@ -202,6 +254,7 @@ struct TransformConvFwdToGemm
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
......@@ -219,10 +272,10 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I1]},
ZYX_{Y_ * X_}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
......@@ -233,7 +286,6 @@ struct TransformConvFwdToGemm
{
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Ho_ * Wo_;
}
template <typename ConvDimsType,
......@@ -264,6 +316,8 @@ struct TransformConvFwdToGemm
DiStride_{a_g_n_c_wis_strides[I3]},
HiStride_{a_g_n_c_wis_strides[I4]},
WiStride_{a_g_n_c_wis_strides[I5]},
DoStride_{c_g_n_k_wos_strides[I3]},
HoStride_{c_g_n_k_wos_strides[I4]},
WoStride_{c_g_n_k_wos_strides[I5]},
XStride_{b_g_k_c_xs_strides[I5]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
......@@ -271,6 +325,7 @@ struct TransformConvFwdToGemm
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
NStrideTensorC_{c_g_n_k_wos_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
......@@ -288,10 +343,10 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I2]},
ZYX_{Z_ * Y_ * X_}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
if constexpr(SplitN)
{
......@@ -302,7 +357,122 @@ struct TransformConvFwdToGemm
{
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
}
__host__ bool AreDescriptorsSmallerThan2GB() const
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t in_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
}
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{
// Create copies
auto conv_to_gemm_transformer_left = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate real filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow to split if whole left padding will be in left tensor and right padding in right
// tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
;
// Calcualte offsets
a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
c_right_offset = (Do_ / 2) * DoStride_;
}
else if(is_possible_to_split_h)
{
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
conv_to_gemm_transformer_right.Hi_ =
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
c_right_offset = (Ho_ / 2) * HoStride_;
}
else if(is_possible_to_split_w)
{
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
conv_to_gemm_transformer_right.Wi_ =
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
c_right_offset = (Wo_ / 2) * WoStride_;
}
// Return left transform, right transformer, right offset to Input and right offset to
// Output
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
......@@ -320,20 +490,27 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wo_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(N_, Wo_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
......@@ -527,20 +704,29 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
......@@ -759,20 +945,34 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
......@@ -1119,45 +1319,70 @@ struct TransformConvFwdToGemm
}
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
template <
typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 2 &&
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 3 &&
(is_same_v<CLayout, tensor_layout::convolution::G_K>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNWK>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const IndexType NDoHoWo = N_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_pass_through_transform(NDoHoWo_),
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We need only matrices from diagonal. X_or returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
......@@ -1167,7 +1392,7 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo_),
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
......@@ -1175,45 +1400,146 @@ struct TransformConvFwdToGemm
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
index_t NDimSp = NDimSpatial,
typename std::enable_if<
NDimSp == 2 && (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
return out_gemmm_gemmn_desc;
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc =
make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(NStrideTensorC_,
HoStride_,
WoStride_,
GStrideTensorC_,
KStrideTensorC_,
GStrideTensorC_));
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We need only matrices from diagonal. X_or returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
public:
index_t N_;
template <typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<
NDimSp == 3 && (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
private:
const index_t Di_, Hi_, Wi_;
const index_t Do_, Ho_, Wo_;
const index_t Z_, Y_, X_;
const index_t K_, C_;
const index_t DiStride_, HiStride_, WiStride_;
const index_t WoStride_;
const index_t XStride_;
const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
const index_t NStrideTensorA_;
const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
const index_t InRightPadD_, InRightPadH_, InRightPadW_;
const index_t ZYX_;
index_t NDoHoWo_;
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc =
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(NStrideTensorC_,
DoStride_,
HoStride_,
WoStride_,
GStrideTensorC_,
KStrideTensorC_,
GStrideTensorC_));
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We need only matrices from diagonal. X_or returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;
IndexType K_, C_;
IndexType DiStride_, HiStride_, WiStride_;
IndexType DoStride_, HoStride_, WoStride_;
IndexType XStride_;
IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
IndexType NStrideTensorA_, NStrideTensorC_;
IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType ZYX_;
};
// wrapper class to call member functions on TransformConvToGemm struct at runtime
......@@ -1230,17 +1556,17 @@ struct TransformConv
if(NDimSpatial == 2)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK, 2>();
}
else if(NDimSpatial == 3)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK, 3>();
}
else if(NDimSpatial == 1)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK, 1>();
}
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
namespace ck {
// Define the common macro for gfx94x models
......@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif
}
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
public:
Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads)
: input_{input},
output_{output},
conv_strides_{conv_filter_strides},
......@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<long_index_t> conv_strides_;
std::vector<long_index_t> conv_dilations_;
std::vector<long_index_t> in_left_pads_;
std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
std::vector<long_index_t> filter_spatial_lengths_;
std::vector<long_index_t> output_spatial_lengths_;
private:
void initOutputSpatialLengths()
{
constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
for(ck::long_index_t i = 0; i < NDimSpatial; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
......@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
const long_index_t G = arg.output_.GetLengths()[0];
const long_index_t N = arg.output_.GetLengths()[1];
const long_index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) {
for(long_index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
index_t column = 0;
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
......@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else if constexpr(NDimSpatial == 2)
{
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
const long_index_t Ho = arg.output_spatial_lengths_[0];
const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
for(long_index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
for(long_index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 &&
......@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
else if constexpr(NDimSpatial == 3)
{
const index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
const long_index_t Do = arg.output_spatial_lengths_[0];
const long_index_t Ho = arg.output_spatial_lengths_[1];
const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
for(long_index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
for(long_index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
for(long_index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{
auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho *
......@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2];
++x)
{
auto wi =
static_cast<ck::long_index_t>(
......@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c)
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
......@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg)
{
const ck::index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2];
const ck::long_index_t G = arg.output_.GetLengths()[0];
const ck::long_index_t N = arg.output_.GetLengths()[1];
const ck::long_index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
const long_index_t NDoHoWo =
N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
const long_index_t CZYX =
C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
......@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads)
{
return Argument{input,
output,
......
......@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<long_index_t> conv_strides_;
std::vector<long_index_t> conv_dilations_;
std::vector<long_index_t> in_left_pads_;
std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
......@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
......@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<long_index_t> conv_strides_;
std::vector<long_index_t> conv_dilations_;
std::vector<long_index_t> in_left_pads_;
std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
......@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment