Commit 95a83c6e authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into wavelet_model

parents 5b7c2432 892a8d76
...@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) { auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
std::size_t N = arg.output_.GetLengths()[1];
std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[4];
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n) for(std::size_t n = 0; n < N; ++n)
{ {
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho) for(std::size_t ho = 0; ho < Ho; ++ho)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo) for(std::size_t wo = 0; wo < Wo; ++wo)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
......
...@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
size_t M = acc.mDesc.GetLengths()[0]; size_t M = acc.mDesc.GetLengths()[0];
size_t N = acc.mDesc.GetLengths()[1]; size_t N = acc.mDesc.GetLengths()[1];
Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M}))); Tensor<ComputeDataType> avg_acc_sq({M});
Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M}))); Tensor<ComputeDataType> avg_acc({M});
Tensor<ComputeDataType> acc_layernorm(acc); Tensor<ComputeDataType> acc_layernorm(acc);
// reduce N dim // reduce N dim
......
...@@ -92,9 +92,10 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -92,9 +92,10 @@ struct ReferenceLayernorm : public device::BaseOperator
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_); auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val); arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
} }
} }
......
...@@ -60,6 +60,12 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -60,6 +60,12 @@ struct ReferenceSoftmax : public device::BaseOperator
{ {
scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]); scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
} }
// max and sum reduction with final reduced values of dim=0 is a scalar so give it
// appropriate lengths of {1}
if(arg.sm_scalar_dims_.size() == 0)
{
scalar_lengths.push_back(1);
}
Tensor<AccDataType> reduce_max(scalar_lengths); Tensor<AccDataType> reduce_max(scalar_lengths);
reduce_max.GenerateTensorValue( reduce_max.GenerateTensorValue(
...@@ -67,6 +73,9 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -67,6 +73,9 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<AccDataType> reduce_sum(scalar_lengths); Tensor<AccDataType> reduce_sum(scalar_lengths);
reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}); reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
// when final reduced values is of dim=0, the index will be transformed into empty
// std::vector which is actually a valid input for Tensor::operator(std::vector) and
// internally accesses 0'th element
auto to_sm_scalar_idx = [&](auto idx) { auto to_sm_scalar_idx = [&](auto idx) {
std::vector<size_t> sm_scalar_idx; std::vector<size_t> sm_scalar_idx;
for(index_t dim : arg.sm_scalar_dims_) for(index_t dim : arg.sm_scalar_dims_)
...@@ -77,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -77,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator
}; };
arg.in_.ForEach([&](auto& self, auto idx) { arg.in_.ForEach([&](auto& self, auto idx) {
reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)), reduce_max(to_sm_scalar_idx(idx)) = std::max(
static_cast<AccDataType>(self(idx))); reduce_max(to_sm_scalar_idx(idx)), ck::type_convert<AccDataType>(self(idx)));
}); });
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") << // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
...@@ -87,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -87,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<AccDataType> in_stable(arg.in_.mDesc); Tensor<AccDataType> in_stable(arg.in_.mDesc);
in_stable.ForEach([&](auto& self, auto idx) { in_stable.ForEach([&](auto& self, auto idx) {
// numerator = exp(x - max(x)) // numerator = exp(x - max(x))
self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) - self(idx) = std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
reduce_max(to_sm_scalar_idx(idx))); reduce_max(to_sm_scalar_idx(idx)));
}); });
...@@ -102,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -102,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator
// std::endl; // std::endl;
arg.out_.ForEach([&](auto& self, auto idx) { arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) + AccDataType temp_result =
arg.beta_ * self(idx); arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
arg.beta_ * self(idx);
self(idx) = ck::type_convert<OutDataType>(temp_result);
}); });
// LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#pragma once #pragma once
#include <cstdlib> #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -15,6 +18,8 @@ using F64 = double; ...@@ -15,6 +18,8 @@ using F64 = double;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t;
using I32 = int32_t;
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
...@@ -23,6 +28,8 @@ using F16_F16_Tuple = ck::Tuple<F16, F16>; ...@@ -23,6 +28,8 @@ using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -70,13 +77,25 @@ using NWGK = ck::tensor_layout::convolution::NWGK; ...@@ -70,13 +77,25 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
using NHWGK = ck::tensor_layout::convolution::NHWGK; using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
//
using GK = ck::tensor_layout::convolution::G_K;
using GK_TUPLE = ck::Tuple<GK>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
template <typename DeviceOp> template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory; struct DeviceOperationInstanceFactory;
} // namespace instance } // namespace instance
......
...@@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g ...@@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); false>>>& instances);
void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
Col,
Row,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
true>>>& instances);
template <typename ALayout, template <typename ALayout,
typename B0Layout, typename B0Layout,
...@@ -39,7 +56,8 @@ template <typename ALayout, ...@@ -39,7 +56,8 @@ template <typename ALayout,
typename ADataType, typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType> typename CDataType,
bool MaskOutUpperTriangle>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout, ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -51,9 +69,10 @@ struct DeviceOperationInstanceFactory< ...@@ -51,9 +69,10 @@ struct DeviceOperationInstanceFactory<
CDataType, CDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>> MaskOutUpperTriangle>>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout, using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -65,9 +84,10 @@ struct DeviceOperationInstanceFactory< ...@@ -65,9 +84,10 @@ struct DeviceOperationInstanceFactory<
CDataType, CDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>; MaskOutUpperTriangle>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -79,8 +99,16 @@ struct DeviceOperationInstanceFactory< ...@@ -79,8 +99,16 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>) is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{ {
add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( if constexpr(MaskOutUpperTriangle)
op_ptrs); {
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
else
{
add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
} }
} }
return op_ptrs; return op_ptrs;
......
...@@ -17,63 +17,131 @@ namespace tensor_operation { ...@@ -17,63 +17,131 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
template <ck::index_t... Is> void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
using S = ck::Sequence<Is...>; std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
using CPermuteNumDims_G_M_O = void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row, std::vector<std::unique_ptr<
Col, DeviceBatchedGemmSoftmaxGemmPermute<2,
Row, 1,
CPermuteNumDims_G_M_O, 1,
F16, 1,
F16, 1,
F16, BF16,
F16, BF16,
PassThrough, BF16,
PassThrough, BF16,
Scale, ck::Tuple<>,
PassThrough, ck::Tuple<>,
PassThrough>>>& instances); PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
template <typename ALayout, void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
typename B0Layout, std::vector<
typename B1Layout, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
typename CPermuteNumDims_G_M_Gemm1N, 1,
typename ADataType, 1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType> typename CDataType,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout, ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
B0Layout, 1,
B1Layout, 1,
CPermuteNumDims_G_M_Gemm1N, 1,
1,
ADataType, ADataType,
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
ck::Tuple<>,
ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale,
PassThrough, PassThrough,
PassThrough>> PassThrough,
MaskingSpec>>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<ALayout, using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
B0Layout, 1,
B1Layout, 1,
CPermuteNumDims_G_M_Gemm1N, 1,
1,
ADataType, ADataType,
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
ck::Tuple<>,
ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale, Scale,
PassThrough, PassThrough,
PassThrough>; PassThrough,
MaskingSpec>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -82,11 +150,28 @@ struct DeviceOperationInstanceFactory< ...@@ -82,11 +150,28 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>) is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> && if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
is_same_v<B1Layout, Row> && {
is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>) add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{ {
add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
} }
} }
......
...@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// conv2d dl
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward data // conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
...@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
...@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
} }
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> && else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_elementwise_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceElementwiseNormalization<ck::Tuple<F16, F16>,
F16,
F16,
F32,
F16,
element_wise::Add,
PassThrough,
2,
1>>>&);
template <typename InDataTypeTuple,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwiseNormalization<
InDataTypeTuple,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Add,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceElementwiseNormalization<InDataTypeTuple,
GammaDataType,
BetaDataType,
F32,
YDataType,
ck::tensor_operation::element_wise::Add,
ck::tensor_operation::element_wise::PassThrough,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<GammaDataType, F16> && is_same_v<BetaDataType, F16> &&
is_same_v<YDataType, F16>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
add_device_elementwise_normalization_rank_2_1_f16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d backward data
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
typename InLayout,
typename OutDataType,
typename WeiDataType,
typename InDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
NumDimSpatial,
OutLayout,
WeiLayout,
Empty_Tuple,
InLayout,
OutDataType,
WeiDataType,
Empty_Tuple,
InDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp =
DeviceGroupedConvBwdDataMultipleD<NumDimSpatial,
OutLayout,
WeiLayout,
Empty_Tuple,
InLayout,
OutDataType,
WeiDataType,
Empty_Tuple,
InDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward weight
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_conv2d_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_TUPLE,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_TUPLE,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul_Clamp<Relu>>>>&
instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_TUPLE> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_bias_perlayer_quantization_int8_instances(op_ptrs);
else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
#pragma once #pragma once
#include <cstdlib> #include <vector>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwd<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_conv2d_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<Relu>>>>&
instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
Empty_Tuple,
OutLayout,
InDataType,
WeiDataType,
Empty_Tuple,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
Empty_Tuple,
OutLayout,
InDataType,
WeiDataType,
Empty_Tuple,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_perlayer_quantization_int8_instances(op_ptrs);
else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_relu_perlayer_quantization_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -3,24 +3,77 @@ ...@@ -3,24 +3,77 @@
#pragma once #pragma once
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp"
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp"
namespace ck { namespace ck {
...@@ -63,33 +65,20 @@ using reduce_configuration_2_instances_blockwise = std::tuple< ...@@ -63,33 +65,20 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
>; >;
#endif #endif
template <ReduceTensorOp ReduceOpId>
using deviceReduceBlockWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
int Rank, int Rank,
int NumReduceDim, int NumReduceDim,
ReduceTensorOp ReduceOpId, typename ReduceOperation,
typename InElementwiseOp,
typename AccElementwiseOp,
bool PropagateNan, bool PropagateNan,
bool UseIndex> bool OutputIndex>
void add_device_reduce_instance_blockwise( void add_device_reduce_instance_blockwise(
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>>& device_op_instances) std::vector<DeviceReducePtr<Rank, NumReduceDim, InElementwiseOp, AccElementwiseOp>>&
device_op_instances)
{ {
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
ReduceOpId == ReduceTensorOp::AMAX);
constexpr bool OutputIndex = Indexable && UseIndex;
static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}( static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
[&](auto i) { [&](auto i) {
using cfg1 = remove_cvref_t<decltype( using cfg1 = remove_cvref_t<decltype(
...@@ -107,8 +96,8 @@ void add_device_reduce_instance_blockwise( ...@@ -107,8 +96,8 @@ void add_device_reduce_instance_blockwise(
Rank, Rank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOp,
AccElementwiseOperation, AccElementwiseOp,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
PropagateNan, PropagateNan,
OutputIndex, OutputIndex,
...@@ -128,52 +117,6 @@ void add_device_reduce_instance_blockwise( ...@@ -128,52 +117,6 @@ void add_device_reduce_instance_blockwise(
}); });
}; };
#define ADD_BLOCKWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<bool>(NanOpt), \
static_cast<bool>(IndicesOpt), \
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<bool>(NanOpt), \
static_cast<bool>(IndicesOpt), \
Rank, \
NumReduceDim)
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment