Unverified Commit 4ec5c52a authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Add Grouped Conv Fwd Large Tensor kernel (#1432)

* Support 64 bit indexing

* Add new grouped conv fwd kernel for large tensors

* Add instances large tensor

* Fixes for transform conv to gemm

* Fixes

* fixes

* Remove not needed instances

* examples fixes

* Remove not need ds arrays

* Fix tests

* Add 2GB check in gridwise dl

* Fixes
parent 7f57b2e0
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
...@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc, ...@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
inline HostTensorDescriptor inline HostTensorDescriptor
make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size) make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
{ {
std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_}; std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};
ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions)); ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification, ...@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
// reset input to zero // reset input to zero
in_device_buf.SetZero(); in_device_buf.SetZero();
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
for(ck::index_t d = 0; d < NDimSpatial; d++)
{
input_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
filter_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
output_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
conv_filter_dilations_i32[d] =
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
}
// do GEMM // do GEMM
auto conv = DeviceConvNdBwdDataInstance{}; auto conv = DeviceConvNdBwdDataInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
...@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification, ...@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_, static_cast<ck::index_t>(conv_param.N_),
conv_param.K_, static_cast<ck::index_t>(conv_param.K_),
conv_param.C_, static_cast<ck::index_t>(conv_param.C_),
conv_param.input_spatial_lengths_, input_spatial_lengths_i32,
conv_param.filter_spatial_lengths_, filter_spatial_lengths_i32,
conv_param.GetOutputSpatialLengths(), output_spatial_lengths_i32,
conv_param.conv_filter_strides_, conv_filter_strides_i32,
conv_param.conv_filter_dilations_, conv_filter_dilations_i32,
conv_param.input_left_pads_, input_left_pads_i32,
conv_param.input_right_pads_, input_right_pads_i32,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
......
...@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator ...@@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0; const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_a,
BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>; using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
__host__ __device__ static auto __host__ __device__ static auto
MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
...@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay> template <typename BLay>
__host__ __device__ static auto __host__ __device__ static auto
MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(); conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
...@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay> template <typename ELay>
__host__ __device__ static auto __host__ __device__ static auto
MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
...@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different. // Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast. // Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K = using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>; remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K = using BGridDesc_N_K =
...@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride // D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_; ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
...@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
b_element_op, b_element_op,
cde_element_op}; cde_element_op};
} }
static __device__ __host__ auto MakeArgument(
APointers p_as,
BPointers p_bs,
const ck::Array<const void*, NumDTensor>& p_ds,
void* p_e,
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<long_index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<long_index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<long_index_t, NDimSpatial>& input_left_pads,
const ck::Array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
}; };
} // namespace device } // namespace device
......
...@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl ...@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
static constexpr auto spatial_offset = Number<3>{}; static constexpr auto spatial_offset = Number<3>{};
using GemmToConvFwdTransformer = using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>; TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{ MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
...@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl ...@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
: independent_filter_stride; : independent_filter_stride;
} }
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor {}, // not needed for A Descriptor
......
...@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>; using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
...@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(); conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
...@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
} }
template <typename ELay> template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
...@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
dummy_conv_to_gemm_transformer))>; dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
...@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_; ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
...@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op}; cde_element_op};
} }
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op); cde_element_op);
} }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
......
...@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>; using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
...@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(); conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
...@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} }
template <typename CLay> template <typename CLay>
static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
...@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
dummy_conv_to_gemm_transformer))>; dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
...@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_; ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
......
...@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -316,7 +316,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization, ConvForwardSpecialization,
true /*SplitN*/, true /*SplitN*/,
ALayout, ALayout,
...@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
...@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(); conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
...@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
template <typename ELay> template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
...@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different. // Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast. // Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K = using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>; remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K = using BGridDesc_N_K =
...@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_.BatchStrideDs_(i) = compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_; ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_; index_t conv_N_per_block_;
...@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op}; cde_element_op};
} }
static auto
MakeArgument(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
APointers p_a, APointers p_as,
BPointers p_b, BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
...@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override const CDEElementwiseOperation& cde_element_op) override
{ {
return std::make_unique<Argument>(p_a, return std::make_unique<Argument>(p_as,
p_b, p_bs,
p_ds, p_ds,
p_e, p_e,
a_g_n_c_wis_lengths, a_g_n_c_wis_lengths,
...@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op); cde_element_op);
} }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
......
...@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization, ConvForwardSpecialization,
true /*SplitN*/, true /*SplitN*/,
ADataType, ADataType,
...@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
...@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(); conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
...@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
} }
template <typename ELay> template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
...@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using EGridDesc_M_N = using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>; remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
...@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_; ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_; index_t conv_N_per_block_;
...@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false; return false;
} }
// Gridwise gemm v3 doesn't verify descriptors size
if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB())
{
return false;
}
// check Gridwise GEMM // check Gridwise GEMM
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
...@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op}; cde_element_op};
} }
static auto
MakeArgument(const void* p_as,
const void* p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op); cde_element_op);
} }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
......
...@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>; using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
...@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(); conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
...@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
} }
template <typename ELay> template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
...@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
} }
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K = using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>; remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K = using BGridDesc_N_K =
...@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride // D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType* p_e_grid_; EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_; typename GridwiseGemm::RsGridPointer p_rs_grid_;
GemmToConvFwdTransformer conv_to_gemm_transformer_; ConvToGemmFwdTransformer conv_to_gemm_transformer_;
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
......
...@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds = static constexpr auto BEnableLds =
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>; using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(); conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
...@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(); conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
...@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
template <typename ELay> template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(); conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
...@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc = using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer)); decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
using BGridDesc = using BGridDesc =
...@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_; ConvToGemmFwdTransformer conv_to_gemm_transformer_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
...@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op}; cde_element_op};
} }
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op); cde_element_op);
} }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
......
...@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl ...@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
using GemmToConvFwdTransformer = using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>; TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
...@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl ...@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths[I2] = C; b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N; c_g_n_k_wos_lengths[I1] = N;
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor {}, // not needed for A Descriptor
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ...@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n) const CGridDesc_M_N& c_grid_desc_m_n)
{ {
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
{
return false;
}
const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3 ...@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1, const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n) const CGridDesc_M_N& c_grid_desc_m_n)
{ {
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
{
return false;
}
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2); const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2); const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1); const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp" #include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp" #include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
namespace ck { namespace ck {
// Define the common macro for gfx94x models // Define the common macro for gfx94x models
...@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) ...@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif #endif
} }
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
for(std::size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
// Declare a template function for bf16 conversion using RTN // Declare a template function for bf16 conversion using RTN
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x); __host__ __device__ constexpr Y bf16_convert_rtn(X x);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
public: public:
Argument(const Tensor<InDataType>& input, Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
: input_{input}, : input_{input},
output_{output}, output_{output},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
...@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
const Tensor<InDataType>& input_; const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_; Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_; std::vector<long_index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_; std::vector<long_index_t> output_spatial_lengths_;
private: private:
void initOutputSpatialLengths() void initOutputSpatialLengths()
{ {
constexpr auto input_offset_to_spatial = 3; constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i) for(ck::long_index_t i = 0; i < NDimSpatial; ++i)
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back( output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + (output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
...@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
const index_t G = arg.output_.GetLengths()[0]; const long_index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1]; const long_index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2]; const long_index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Wo + wo; long_index_t row = n * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
...@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
const index_t Ho = arg.output_spatial_lengths_[0]; const long_index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1]; const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho) for(long_index_t ho = 0; ho < Ho; ++ho)
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Ho * Wo + ho * Wo + wo; long_index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(hi >= 0 && if(hi >= 0 &&
...@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
const index_t Do = arg.output_spatial_lengths_[0]; const long_index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1]; const long_index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2]; const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o) for(long_index_t d_o = 0; d_o < Do; ++d_o)
{ {
for(index_t ho = 0; ho < Ho; ++ho) for(long_index_t ho = 0; ho < Ho; ++ho)
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{ {
auto di = auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * static_cast<ck::long_index_t>(ho *
...@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>(y * static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) - arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2];
++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
...@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) - x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]); static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
...@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
const ck::index_t G = arg.output_.GetLengths()[0]; const ck::long_index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1]; const ck::long_index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2]; const ck::long_index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo = const long_index_t NDoHoWo =
N * ck::accumulate_n<index_t>( N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX = const long_index_t CZYX =
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) && if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
...@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
{ {
return Argument{input, return Argument{input,
output, output,
......
...@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input, Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input, Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi, const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x, Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo, const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi, const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x, Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo, const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment