"examples/vscode:/vscode.git/clone" did not exist on "b90b0570926083771ac838cc7906f31d88ba9f49"
Unverified Commit 46a675aa authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Add examples of Conv + reduction (data type: int4, int8, bf16, fp16, fp32) (#380)



* Refactor the design of DeviceGemmMultipleDMultipleR_Xdl_CShuffle

* Add 'DeviceGroupedConvFwdMultipleDMultipleR' interface

* Add DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle

* Remove 'GridwiseConvFwdMultipleDMultipleR_xdl_cshuffle'

* Add 'TransformConvFwdToGemm<>' utility class (from Chao)

* Use 'TransformConvFwdToGemm<>' to shorten code

* Fix ill-formed method declaration

* Re-implement MakeRGridDescriptor_M() function

* Change problem description

* Use macro to define layout types

* Define K-reduced output tensor layout types

* Let user to decide R output tensor layout

* Rename variables

* Add padding to the reduced output tensor if necessary

* Extract common code as helper method

* Remove debug message

* Add missing include directive

* Add partial fp16 Conv + Reduction example

* Add example verification code for 2D Conv problem

* Use type alias to simplify code

* Share code across different-dimension Conv problems

* Rename file/functions from run_conv_fwd* to run_convnd_fwd*

* Make example code more verbose

* Add code to support 1D & 3D Conv + Reduction on host

* Add more examples for data type: bf16, fp32

* Add example for int8

* Add custom target to group examples

* Use more general custom target name

* Change the description in error message

* Disable testing for example other than fp32

* Add examplel for int4 (just copy from int8)

* Fix wrong data type

* Use larger data type for intermediate tensors

* Finish int4 example

* Undefine macro PP_DEFINE_LAYOUT_TYPE() after use

* Use named variables to replace magic numbers

* Remove debug messages

* Use same A/B data type for host Conv in int4 example

* Add check for the 'RLayout' type argument

* Group same-dim-layouts together in 'LayoutSetting<>'

* Add 'final' specifier to utility classes

* Use different initialization method for examples

* Remove macro PP_DEFINE_LAYOUT_TYPE()

* Fix code-comment mismatch

* Use more reasonable initialization value for all data types

* Default use init_method=1 for all examples

* Remove never-used code

* Remove confusing out-of-date comments

* clean
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
Co-authored-by: default avatarChao Liu <lc.roy86@gmail.com>
parent 4df6d93f
add_custom_target(example_convnd_fwd_reduce_xdl)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
using BF16 = ck::bhalf_t;
using FP16 = ck::half_t;
using FP32 = float;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using I4 = ck::int4_t;
#endif
using I8 = std::int8_t;
using I32 = std::int32_t;
template <typename ALay, typename BLay, typename DELay, typename RLay>
struct LayoutSetting
{
using ALayout = ALay;
using BLayout = BLay;
using DELayout = DELay;
using RLayout = RLay;
};
template <ck::index_t NDimSpatial>
struct LayoutSettingSelector;
namespace ctl = ck::tensor_layout::convolution;
template <>
struct LayoutSettingSelector<1> final : LayoutSetting<ctl::GNWC, ctl::GKXC, ctl::GNWK, ctl::GNW>
{
};
template <>
struct LayoutSettingSelector<2> final : LayoutSetting<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK, ctl::GNHW>
{
};
template <>
struct LayoutSettingSelector<3> final
: LayoutSetting<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK, ctl::GNDHW>
{
};
template <ck::index_t NDimSpatial>
using ALayout = typename LayoutSettingSelector<NDimSpatial>::ALayout;
template <ck::index_t NDimSpatial>
using BLayout = typename LayoutSettingSelector<NDimSpatial>::BLayout;
template <ck::index_t NDimSpatial>
using DELayout = typename LayoutSettingSelector<NDimSpatial>::DELayout;
template <ck::index_t NDimSpatial>
using RLayout = typename LayoutSettingSelector<NDimSpatial>::RLayout;
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
inline void print_help_msg()
{
std::cerr << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline bool parse_cmd_args(int argc,
char* argv[],
ck::utils::conv::ConvParam& problem_size,
ExecutionConfig& config)
{
constexpr int num_execution_config_args =
3; // arguments for do_verification, init_method, time_kernel
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
constexpr int threshold_to_catch_all_args =
threshold_to_catch_partial_args + num_conv_param_leading_args;
if(argc == 1)
{
// use default
}
// catch only ExecutionConfig arguments
else if(argc == threshold_to_catch_partial_args)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
// catch both ExecutionConfig & ConvParam arguments
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
problem_size = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args, argv);
}
else
{
print_help_msg();
return false;
}
return true;
}
inline HostTensorDescriptor
make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
{
std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_};
std::copy(begin(problem_size.output_spatial_lengths_),
end(problem_size.output_spatial_lengths_),
std::back_inserter(dimensions));
return HostTensorDescriptor(dimensions);
}
template <typename Lengths, typename Strides>
void unpack_host_tensor_descriptor(const HostTensorDescriptor& descriptor,
Lengths& lengths,
Strides& strides)
{
assert(size(descriptor.GetLengths()) == size(lengths));
std::copy_n(begin(descriptor.GetLengths()), size(descriptor.GetLengths()), begin(lengths));
assert(size(descriptor.GetStrides()) == size(strides));
std::copy_n(begin(descriptor.GetStrides()), size(descriptor.GetStrides()), begin(strides));
}
template <typename Range, typename OutputIterator>
auto copy(const Range& range, OutputIterator iter)
-> decltype(std::copy(std::begin(range), std::end(range), iter))
{
return std::copy(std::begin(range), std::end(range), iter);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = BF16;
using BDataType = BF16;
using AccDataType = FP32;
using CShuffleDataType = FP32;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using ReduceAccDataType = FP32;
using R0DataType = FP32;
using RsDataType = ck::Tuple<R0DataType>;
#include "run_convnd_fwd_max_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = FP16;
using BDataType = FP16;
using AccDataType = FP32;
using CShuffleDataType = FP32;
using DsDataType = ck::Tuple<>;
using EDataType = FP16;
using ReduceAccDataType = FP32;
using R0DataType = FP32;
using RsDataType = ck::Tuple<R0DataType>;
#include "run_convnd_fwd_max_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = FP32;
using BDataType = FP32;
using AccDataType = FP32;
using CShuffleDataType = FP32;
using DsDataType = ck::Tuple<>;
using EDataType = FP32;
using ReduceAccDataType = FP32;
using R0DataType = FP32;
using RsDataType = ck::Tuple<R0DataType>;
#include "run_convnd_fwd_max_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#define BUILD_INT4_EXAMPLE
#include "common.hpp"
using ADataType = I4;
using BDataType = I4;
using KernelADataType = I8;
using KernelBDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using DsDataType = ck::Tuple<>;
using EDataType = I32;
using ReduceAccDataType = I32;
using R0DataType = I32;
using RsDataType = ck::Tuple<R0DataType>;
#include "run_convnd_fwd_max_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = I8;
using BDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using DsDataType = ck::Tuple<>;
using EDataType = I32;
using ReduceAccDataType = I32;
using R0DataType = I32;
using RsDataType = ck::Tuple<R0DataType>;
#include "run_convnd_fwd_max_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough>;
using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
template <ck::index_t NDimSpatial>
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
//######| NDimSpatial| ALayout| BLayout| DELayout| RLayout| AData| BData| AccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| Conv| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Fwd|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| Specialization| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
#ifdef BUILD_INT4_EXAMPLE
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
#else
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
#endif
template <ck::index_t NDimSpatial>
using HostInstance = ck::tensor_operation::host::ReferenceConvFwd
<NDimSpatial, ADataType, BDataType, EDataType, AElementOp, BElementOp, PassThrough>;
// clang-format on
template <ck::index_t NDimSpatial>
bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
const ExecutionConfig& config)
{
static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial");
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
const auto conv_input_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ALayout<NDimSpatial>>(
problem_size);
const auto conv_weight_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<BLayout<NDimSpatial>>(
problem_size);
const auto conv_output_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<DELayout<NDimSpatial>>(
problem_size);
const auto r0_desc = make_r0_host_tensor_descriptor(problem_size);
Tensor<ADataType> conv_input(conv_input_g_n_c_wis_desc);
Tensor<BDataType> conv_weight(conv_weight_g_k_c_xs_desc);
Tensor<EDataType> conv_output_device(conv_output_g_n_k_wos_desc);
Tensor<R0DataType> r0_device(r0_desc);
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input.begin(),
conv_input.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight.begin(),
conv_weight.end());
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input.begin(), conv_input.end());
ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight.begin(),
conv_weight.end());
}
DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize());
DeviceMem conv_weight_device_buf(sizeof(BDataType) * conv_weight.mDesc.GetElementSpaceSize());
DeviceMem conv_output_device_buf(sizeof(EDataType) *
conv_output_device.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_device.mDesc.GetElementSpaceSize());
#ifdef BUILD_INT4_EXAMPLE
const Tensor<KernelADataType> conv_input_converted(conv_input);
const Tensor<KernelBDataType> conv_weight_converted(conv_weight);
conv_input_device_buf.ToDevice(conv_input_converted.mData.data());
conv_weight_device_buf.ToDevice(conv_weight_converted.mData.data());
#else
conv_input_device_buf.ToDevice(conv_input.mData.data());
conv_weight_device_buf.ToDevice(conv_weight.mData.data());
#endif
std::array<ck::index_t, NDimSpatial + 3> conv_input_g_n_c_wis_lengths{},
conv_input_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> conv_weight_g_k_c_xs_lengths{},
conv_weight_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> conv_output_g_n_k_wos_lengths{},
conv_output_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 2> r0_lengths{}, r0_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}, conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}, input_right_pads{};
unpack_host_tensor_descriptor(
conv_input_g_n_c_wis_desc, conv_input_g_n_c_wis_lengths, conv_input_g_n_c_wis_strides);
unpack_host_tensor_descriptor(
conv_weight_g_k_c_xs_desc, conv_weight_g_k_c_xs_lengths, conv_weight_g_k_c_xs_strides);
unpack_host_tensor_descriptor(
conv_output_g_n_k_wos_desc, conv_output_g_n_k_wos_lengths, conv_output_g_n_k_wos_strides);
unpack_host_tensor_descriptor(r0_desc, r0_lengths, r0_strides);
copy(problem_size.conv_filter_strides_, begin(conv_filter_strides));
copy(problem_size.conv_filter_dilations_, begin(conv_filter_dilations));
copy(problem_size.input_left_pads_, begin(input_left_pads));
copy(problem_size.input_right_pads_, begin(input_right_pads));
// run Conv + Reduction on device
auto conv = DeviceInstance<NDimSpatial>{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(conv_input_device_buf.GetDeviceBuffer(),
conv_weight_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
conv_output_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer()},
conv_input_g_n_c_wis_lengths,
conv_input_g_n_c_wis_strides,
conv_weight_g_k_c_xs_lengths,
conv_weight_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
conv_output_g_n_k_wos_lengths,
conv_output_g_n_k_wos_strides,
r0_lengths,
r0_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
AElementOp{},
BElementOp{},
CDEElementOp{},
QsElementOp{},
RsElementOp{});
if(!conv.IsSupportedArgument(argument))
{
std::cerr << "wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<< std::endl;
return false;
}
const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
const std::size_t flop = problem_size.GetFlops();
const std::size_t num_btype = problem_size.GetByte<ADataType, BDataType, EDataType>();
const float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
const float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(config.do_verification)
{
Tensor<EDataType> conv_output_host(conv_output_g_n_k_wos_desc);
// run Conv + Reduction on host
auto ref_conv = HostInstance<NDimSpatial>{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(conv_input,
conv_weight,
conv_output_host,
problem_size.conv_filter_strides_,
problem_size.conv_filter_dilations_,
problem_size.input_left_pads_,
problem_size.input_right_pads_,
AElementOp{},
BElementOp{},
PassThrough{});
ref_invoker.Run(ref_argument);
Tensor<R0DataType> r0_host(r0_device.mDesc);
auto reduce0_op = RsThreadReduceOp{}[ck::Number<0>{}];
auto& output_dims = conv_output_g_n_k_wos_desc.GetLengths();
if constexpr(NDimSpatial == 1)
{
for(std::size_t g = 0; g < output_dims[0]; ++g)
{
for(std::size_t n = 0; n < output_dims[1]; ++n)
{
for(std::size_t w = 0; w < output_dims[3]; ++w)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
for(std::size_t k = 0; k < output_dims[2]; ++k)
{
auto e_val =
ck::type_convert<ReduceAccDataType>(conv_output_host(g, n, k, w));
reduce0_op(reduce0_acc, e_val);
}
r0_host(g, n, w) = ck::type_convert<R0DataType>(reduce0_acc);
}
}
}
}
else if constexpr(NDimSpatial == 2)
{
for(std::size_t g = 0; g < output_dims[0]; ++g)
{
for(std::size_t n = 0; n < output_dims[1]; ++n)
{
for(std::size_t h = 0; h < output_dims[3]; ++h)
{
for(std::size_t w = 0; w < output_dims[4]; ++w)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
for(std::size_t k = 0; k < output_dims[2]; ++k)
{
auto e_val = ck::type_convert<ReduceAccDataType>(
conv_output_host(g, n, k, h, w));
reduce0_op(reduce0_acc, e_val);
}
r0_host(g, n, h, w) = ck::type_convert<R0DataType>(reduce0_acc);
}
}
}
}
}
else if constexpr(NDimSpatial == 3)
{
for(std::size_t g = 0; g < output_dims[0]; ++g)
{
for(std::size_t n = 0; n < output_dims[1]; ++n)
{
for(std::size_t d = 0; d < output_dims[3]; ++d)
{
for(std::size_t h = 0; h < output_dims[4]; ++h)
{
for(std::size_t w = 0; w < output_dims[5]; ++w)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
for(std::size_t k = 0; k < output_dims[2]; ++k)
{
auto e_val = ck::type_convert<ReduceAccDataType>(
conv_output_host(g, n, k, d, h, w));
reduce0_op(reduce0_acc, e_val);
}
r0_host(g, n, d, h, w) = ck::type_convert<R0DataType>(reduce0_acc);
}
}
}
}
}
}
conv_output_device_buf.FromDevice(conv_output_device.mData.data());
r0_device_buf.FromDevice(r0_device.mData.data());
return ck::utils::check_err(conv_output_device.mData,
conv_output_host.mData,
"Error: incorrect results! (Matrix E)",
1e-5f,
1e-4f) &&
ck::utils::check_err(r0_device.mData,
r0_host.mData,
"Error: incorrect results! (Matrix R0)",
1e-5f,
1e-4f);
}
return true;
}
bool run_convnd_fwd_max_example(int argc, char* argv[])
{
ck::utils::conv::ConvParam problem_size{
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
ExecutionConfig config;
if(!parse_cmd_args(argc, argv, problem_size, config))
{
return false;
}
switch(problem_size.num_dim_spatial_)
{
case 1: return run_convnd_fwd_max<1>(problem_size, config);
case 2: return run_convnd_fwd_max<2>(problem_size, config);
case 3: return run_convnd_fwd_max<3>(problem_size, config);
}
return false;
}
......@@ -26,6 +26,7 @@ add_subdirectory(02_gemm_bilinear)
add_subdirectory(03_gemm_bias_relu)
add_subdirectory(04_gemm_add_add_fastgelu)
add_subdirectory(09_convnd_fwd)
add_subdirectory(10_convnd_fwd_multiple_d_multiple_reduce)
add_subdirectory(12_reduce)
add_subdirectory(13_pool2d_fwd)
add_subdirectory(14_gemm_xdl_requant_relu_requant)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Grouped Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// output : R0[G, N, Ho, Wo], R1[G, N, Ho, Wo], ...
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DELayout,
typename RLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename RsDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename QsElementwiseOperation,
typename RsElementwiseOperation>
struct DeviceGroupedConvFwdMultipleDMultipleR : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumRTensor = RsDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_multiple_r.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
template <index_t NumDTensor, index_t NumRTensor>
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE,
Array<ck::index_t, NumRTensor> BatchStrideRs)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE),
BatchStrideRs_(BatchStrideRs)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
__host__ __device__ constexpr auto GetRsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumRTensor> rs_offset;
static_for<0, NumRTensor, 1>{}(
[&](auto i) { rs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideRs_[i]); });
return rs_offset;
}
index_t BatchStrideA_;
index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
Array<ck::index_t, NumRTensor> BatchStrideRs_;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename RsPointer,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename QsElementwiseOperation,
typename RsElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename RsGridDescriptor_MBlock_MPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batch_gemm_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
RsPointer p_rs_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const QsElementwiseOperation qs_element_op,
const RsElementwiseOperation rs_element_op,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
RsPointer p_rs_grid_grp;
static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size();
static_for<0, NumRTensor, 1>{}(
[&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_rs_grid_grp,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
rs_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = p_rs_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = rs_grid_desc_mblock_mperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = qs_element_op;
ignore = rs_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif
}
} // namespace
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DELayout,
typename RLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ReduceAccDataType,
typename RsDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename QsElementwiseOperation,
typename RsElementwiseOperation,
typename ThreadReduceOperations,
typename RsGlobalMemoryDataOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
index_t RThreadTransferDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleDMultipleR<NDimSpatial,
ALayout,
BLayout,
DELayout,
RLayout,
ADataType,
BDataType,
DsDataType,
EDataType,
RsDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
RsElementwiseOperation,
QsElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumRTensor = RsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
template <typename Descriptor>
static auto GetPaddedRGridDescriptor(Descriptor descriptor, index_t MRaw)
{
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(
descriptor,
make_tuple(make_right_pad_transform(descriptor, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return descriptor;
}
}
template <typename RLay,
typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::GNW> ||
is_same_v<RLay, tensor_layout::convolution::GNHW> ||
is_same_v<RLay, tensor_layout::convolution::GNDHW>,
bool>::type = false>
static auto
MakeRGridDescriptor_M(const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& /* r_g_n_wos_strides */)
{
const index_t N = r_g_n_wos_lengths[1];
const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2,
r_g_n_wos_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo));
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
template <typename RLay,
typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::G_NW> ||
is_same_v<RLay, tensor_layout::convolution::G_NHW> ||
is_same_v<RLay, tensor_layout::convolution::G_NDHW> ||
is_same_v<RLay, tensor_layout::convolution::NWG> ||
is_same_v<RLay, tensor_layout::convolution::NHWG> ||
is_same_v<RLay, tensor_layout::convolution::NDHWG>,
bool>::type = false>
static auto MakeRGridDescriptor_M(const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides)
{
const index_t N = r_g_n_wos_lengths[1];
const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2];
const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2,
r_g_n_wos_lengths.begin() + 2 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto r_grid_desc_mraw =
make_naive_tensor_descriptor(make_tuple(NHoWo), make_tuple(WoStride));
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
using AGridDesc_M_K = remove_cvref_t<decltype(
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>;
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ReduceAccDataType,
RsDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
QsElementwiseOperation,
RsElementwiseOperation,
ThreadReduceOperations,
InMemoryDataOperationEnum::Set,
RsGlobalMemoryDataOperation,
AGridDesc_M_K,
BGridDesc_N_K,
EGridDesc_M_N,
RGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
RThreadTransferDstScalarPerVector_MPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
p_rs_grid_{}, // FIXME
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<DELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
r_grid_desc_m_{
DeviceOp::MakeRGridDescriptor_M<RLayout>(r_g_n_wos_lengths, r_g_n_wos_strides)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
rs_grid_desc_mblock_mperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
qs_element_op_{qs_element_op},
rs_element_op_{rs_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
e_grid_desc_m_n_,
r_grid_desc_m_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DELayout>(
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_(i));
});
// populate pointer for Rs
static_for<0, NumRTensor, 1>{}([&](auto i) {
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
// R pointer
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
rs_grid_desc_mblock_mperblock_(i) =
GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_);
});
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
EGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
RGridDesc_M r_grid_desc_m_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
StaticallyIndexedArray<typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, NumRTensor>
rs_grid_desc_mblock_mperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
ComputePtrOffsetOfStridedBatch<NumDTensor, NumRTensor> compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
QsElementwiseOperation qs_element_op_;
RsElementwiseOperation rs_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
arg.r_grid_desc_m_,
arg.block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) *
arg.a_g_n_c_wis_lengths_[0]; // Group count
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
typename GridwiseGemm::RsGridPointer,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
QsElementwiseOperation,
RsElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ck::StaticallyIndexedArray<
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
NumRTensor>,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor, NumRTensor>,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.p_rs_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.qs_element_op_,
arg.rs_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.rs_grid_desc_mblock_mperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check device
if(get_device_name() == "gfx908")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else if(get_device_name() == "gfx90a")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
{
return false;
}
}
else
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of Ds
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
// FIXME: layout
if constexpr(is_same_v<DELayout, ctc::G_NW_K> || is_same_v<DELayout, ctc::G_NHW_K> ||
is_same_v<DELayout, ctc::G_NDHW_K> || is_same_v<DELayout, ctc::GNWK> ||
is_same_v<DELayout, ctc::GNHWK> || is_same_v<DELayout, ctc::GNDHWK> ||
is_same_v<DELayout, ctc::NWGK> || is_same_v<DELayout, ctc::NHWGK> ||
is_same_v<DELayout, ctc::NDHWGK>)
{
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
valid = false;
}
}
else
{
valid = false;
}
});
if(!valid)
{
return false;
}
// check vector access of E
if constexpr(is_same_v<DELayout, ctc::G_NW_K> || is_same_v<DELayout, ctc::G_NHW_K> ||
is_same_v<DELayout, ctc::G_NDHW_K> || is_same_v<DELayout, ctc::GNWK> ||
is_same_v<DELayout, ctc::GNHWK> || is_same_v<DELayout, ctc::GNDHWK> ||
is_same_v<DELayout, ctc::NWGK> || is_same_v<DELayout, ctc::NHWGK> ||
is_same_v<DELayout, ctc::NDHWGK>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of R
if constexpr(!(is_same_v<RLayout, ctc::G_NW> || is_same_v<RLayout, ctc::G_NHW> ||
is_same_v<RLayout, ctc::G_NDHW> || is_same_v<RLayout, ctc::GNW> ||
is_same_v<RLayout, ctc::GNHW> || is_same_v<RLayout, ctc::GNDHW> ||
is_same_v<RLayout, ctc::NWG> || is_same_v<RLayout, ctc::NHWG> ||
is_same_v<RLayout, ctc::NDHWG>))
{
return false;
}
// check Gridwise GEMM
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.e_grid_desc_m_n_,
arg.r_grid_desc_m_,
arg.block_2_etile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
p_rs,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
r_g_n_wos_lengths,
r_g_n_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const QsElementwiseOperation& qs_element_op,
const RsElementwiseOperation& rs_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
p_rs,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
r_g_n_wos_lengths,
r_g_n_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -93,7 +93,7 @@ struct GNDHWC : public BaseTensorLayout
};
// input tensor
// packed GNWC/GNHWC/GNDHWC
// packed NWGC/NHWGC/NDHWGC
struct NWGC : public BaseTensorLayout
{
static constexpr const char* name = "NWGC";
......@@ -330,6 +330,54 @@ struct G_NDHW_K : public BaseTensorLayout
static constexpr const char* name = "G_NDHW_K";
};
// K-reduced output tensor (packed)
struct GNW : public BaseTensorLayout
{
static constexpr const char* name = "GNW";
};
struct GNHW : public BaseTensorLayout
{
static constexpr const char* name = "GNHW";
};
struct GNDHW : public BaseTensorLayout
{
static constexpr const char* name = "GNDHW";
};
// K-reduced output tensor (packed)
struct NWG : public BaseTensorLayout
{
static constexpr const char* name = "NWG";
};
struct NHWG : public BaseTensorLayout
{
static constexpr const char* name = "NHWG";
};
struct NDHWG : public BaseTensorLayout
{
static constexpr const char* name = "NDHWG";
};
// K-reduced output tensor (strided)
struct G_NW : public BaseTensorLayout
{
static constexpr const char* name = "G_NW";
};
struct G_NHW : public BaseTensorLayout
{
static constexpr const char* name = "G_NHW";
};
struct G_NDHW : public BaseTensorLayout
{
static constexpr const char* name = "G_NDHW";
};
} // namespace convolution
template <
......
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace utils {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment