Unverified Commit 1462ee22 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into gridwise_2d

parents 2c4305b2 d1567094
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
// Declaration only; the definition is not part of this header.
// Takes, by non-const reference, a vector of owning pointers to the
// DeviceGroupedConvFwd interface for 2 spatial dimensions, GNHWC / GKYXC / GNHWK
// layouts, F16 input / weight / output types and PassThrough element-wise operators.
// NOTE(review): presumably appends the "dl" f16 kernel instances to `instances`
// (inferred from the name and from GetInstances() below, which forwards its result
// vector to this function) - confirm against the definition.
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Declaration only; the definition is not part of this header.
// Same signature shape as the f16 overload above, but the vector's element type is the
// DeviceGroupedConvFwd interface with F32 input / weight / output types.
// NOTE(review): presumably appends the "dl" f32 kernel instances to `instances` -
// confirm against the definition.
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Declaration only; the definition is not part of this header.
// The vector's element type is the DeviceGroupedConvFwd interface for 2 spatial
// dimensions, GNHWC / GKYXC / GNHWK layouts and int8_t input / weight / output types,
// with PassThrough element-wise operators.
// NOTE(review): presumably appends the "dl" int8 kernel instances to `instances` -
// confirm against the definition.
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Instance factory for DeviceGroupedConvFwd operations whose input, weight and output
// element-wise operators are all PassThrough.
//
// GetInstances() returns a vector of owning pointers to DeviceOp. The vector is filled
// by exactly one of the add_device_grouped_conv2d_fwd_dl_* functions declared above when
// the template arguments describe a 2-D GNHWC / GKYXC / GNHWK convolution whose three
// data types are all float, all half_t or all int8_t. For every other combination the
// vector is returned empty.
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwd<
    NumDimSpatial,
    InLayout,
    WeiLayout,
    OutLayout,
    InDataType,
    WeiDataType,
    OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    // The abstract operation type this factory produces instances of.
    using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial,
                                          InLayout,
                                          WeiLayout,
                                          OutLayout,
                                          InDataType,
                                          WeiDataType,
                                          OutDataType,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        // Problem shape / layout for which instances are declared in this header.
        constexpr bool is_gnhwc_conv2d = NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                                         is_same_v<WeiLayout, GKYXC> &&
                                         is_same_v<OutLayout, GNHWK>;

        // One flag per supported data-type triple (input, weight, output).
        constexpr bool all_f32 = is_same_v<InDataType, float> &&
                                 is_same_v<WeiDataType, float> &&
                                 is_same_v<OutDataType, float>;
        constexpr bool all_f16 = is_same_v<InDataType, half_t> &&
                                 is_same_v<WeiDataType, half_t> &&
                                 is_same_v<OutDataType, half_t>;
        constexpr bool all_int8 = is_same_v<InDataType, int8_t> &&
                                  is_same_v<WeiDataType, int8_t> &&
                                  is_same_v<OutDataType, int8_t>;

        std::vector<std::unique_ptr<DeviceOp>> instances;

        // `if constexpr` keeps the non-matching calls out of the instantiation, so only
        // the overload whose parameter type equals `instances` is ever compiled.
        if constexpr(is_gnhwc_conv2d && all_f32)
        {
            add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(instances);
        }
        else if constexpr(is_gnhwc_conv2d && all_f16)
        {
            add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(instances);
        }
        else if constexpr(is_gnhwc_conv2d && all_int8)
        {
            add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(instances);
        }

        return instances;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
// Declaration only; the definition is not part of this header.
// The vector's element type is the DeviceGroupedConvFwdMultipleD interface for 2 spatial
// dimensions with GNHWC / GKYXC / GNHWK layouts, GK_GK_Tuple as the D-tensor layouts,
// int8_t input / weight / output, I32_F32_Tuple as the D-tensor data types and
// Add_Activation_Mul2_Clamp<PassThrough> as the output element-wise operator.
// NOTE(review): presumably appends the bias + per-channel-quantization int8 instances
// (no activation) to `instances` - confirm against the definition.
void add_device_conv2d_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances);
// Declaration only; the definition is not part of this header.
// Identical parameter shape to the non-Relu overload above, except that the output
// element-wise operator is Add_Activation_Mul2_Clamp<Relu>.
// NOTE(review): presumably appends the bias + Relu + per-channel-quantization int8
// instances to `instances` - confirm against the definition.
void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul2_Clamp<Relu>>>>&
instances);
// Instance factory for DeviceGroupedConvFwdMultipleD operations whose input and weight
// element-wise operators are PassThrough and whose output operator is
// Add_Activation_Mul2_Clamp<Activation>.
//
// GetInstances() returns a vector of owning pointers to DeviceOp. It is populated only
// for 2-D GNHWC / GKYXC / GNHWK problems with GK_GK_Tuple D layouts, int8_t input /
// weight / output and I32_F32_Tuple D data types, and only when Activation is
// PassThrough or Relu. Any other combination yields an empty vector.
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DsLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DsDataType,
          typename OutDataType,
          typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
    NumDimSpatial,
    InLayout,
    WeiLayout,
    DsLayout,
    OutLayout,
    InDataType,
    WeiDataType,
    DsDataType,
    OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    Add_Activation_Mul2_Clamp<Activation>>>
{
    // The abstract operation type this factory produces instances of.
    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
                                                   InLayout,
                                                   WeiLayout,
                                                   DsLayout,
                                                   OutLayout,
                                                   InDataType,
                                                   WeiDataType,
                                                   DsDataType,
                                                   OutDataType,
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Add_Activation_Mul2_Clamp<Activation>>;

    static auto GetInstances()
    {
        // Layout combination covered by the declarations in this header.
        constexpr bool layouts_supported =
            NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
            is_same_v<DsLayout, GK_GK_Tuple> && is_same_v<OutLayout, GNHWK>;

        // Data-type combination covered by the declarations in this header.
        constexpr bool types_supported =
            is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
            is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>;

        std::vector<std::unique_ptr<DeviceOp>> instances;

        if constexpr(layouts_supported && types_supported)
        {
            // Select the overload matching the activation wrapped by the output operator.
            if constexpr(is_same_v<Activation, PassThrough>)
            {
                add_device_conv2d_bias_perchannel_quantization_int8_instances(instances);
            }
            else if constexpr(is_same_v<Activation, Relu>)
            {
                add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(instances);
            }
        }

        return instances;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -23,7 +23,7 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances( ...@@ -23,7 +23,7 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances(
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GK_TUPLE, GK_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
...@@ -38,7 +38,7 @@ void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( ...@@ -38,7 +38,7 @@ void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GK_TUPLE, GK_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
...@@ -91,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -91,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_TUPLE> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, GNHWK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
// Declaration only; the definition is not part of this header.
// The vector's element type is the DeviceGroupedConvFwdMultipleD interface for 2 spatial
// dimensions with GNHWC / GKYXC / GNHWK layouts, GK_Tuple as the D-tensor layout,
// int8_t input / weight / output, F32_Tuple as the D-tensor data type and
// Activation_Mul2_Clamp<PassThrough> as the output element-wise operator.
// NOTE(review): presumably appends the per-channel-quantization int8 instances
// (no bias, no activation) to `instances` - confirm against the definition.
void add_device_conv2d_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>&
instances);
// Declaration only; the definition is not part of this header.
// Identical parameter shape to the non-Relu overload above, except that the output
// element-wise operator is Activation_Mul2_Clamp<Relu>.
// NOTE(review): presumably appends the Relu + per-channel-quantization int8 instances
// to `instances` - confirm against the definition.
void add_device_conv2d_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul2_Clamp<Relu>>>>&
instances);
// Instance factory for DeviceGroupedConvFwdMultipleD operations whose input and weight
// element-wise operators are PassThrough and whose output operator is
// Activation_Mul2_Clamp<Activation>.
//
// GetInstances() returns a vector of owning pointers to DeviceOp, where DeviceOp is
// exactly the operation type this specialization was selected for. It is populated only
// for 2-D GNHWC / GKYXC / GNHWK problems with a GK_Tuple D layout, int8_t input /
// weight / output and an F32_Tuple D data type, and only when Activation is PassThrough
// or Relu. Any other combination yields an empty vector.
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DsLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DsDataType,
          typename OutDataType,
          typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
    NumDimSpatial,
    InLayout,
    WeiLayout,
    DsLayout,
    OutLayout,
    InDataType,
    WeiDataType,
    DsDataType,
    OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    Activation_Mul2_Clamp<Activation>>>
{
    // Built from DsLayout / DsDataType rather than hard-coded GK_Tuple / F32_Tuple so
    // that DeviceOp always names the operation type that was actually requested; with
    // hard-coded tuples a caller asking for other D layouts or types would receive a
    // vector of a different operation type.
    using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
                                                   InLayout,
                                                   WeiLayout,
                                                   DsLayout,
                                                   OutLayout,
                                                   InDataType,
                                                   WeiDataType,
                                                   DsDataType,
                                                   OutDataType,
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Activation_Mul2_Clamp<Activation>>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                     is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
                     is_same_v<OutLayout, GNHWK>)
        {
            // DsDataType must be checked as well: the add_* functions below only accept
            // a vector whose element type uses F32_Tuple for the D tensors.
            if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                         is_same_v<DsDataType, F32_Tuple> && is_same_v<OutDataType, int8_t>)
            {
                if constexpr(is_same_v<Activation, PassThrough>)
                    add_device_conv2d_perchannel_quantization_int8_instances(op_ptrs);
                else if constexpr(is_same_v<Activation, Relu>)
                    add_device_conv2d_relu_perchannel_quantization_int8_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -3,4 +3,8 @@ add_instance_library(device_batchnorm_instance ...@@ -3,4 +3,8 @@ add_instance_library(device_batchnorm_instance
device_batchnorm_forward_f32_instance.cpp device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_f64_instance.cpp device_batchnorm_forward_f64_instance.cpp
device_batchnorm_backward_f16_instance.cpp
device_batchnorm_backward_f32_instance.cpp
device_batchnorm_backward_bf16_instance.cpp
device_batchnorm_backward_f64_instance.cpp
) )
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short type aliases used by the instance tables and the registration function below.
using BF16 = ck::bhalf_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Table of DeviceBatchNormBwdImpl configurations with UseMultiBlockInK = false.
// Data types (see the column header): BF16 for X and Scale, F32 for everything else.
// Rank, NumReduceDim and DyElementwiseOp are forwarded from the alias parameters.
// Every entry uses BlockSize 256 and 2 x 2 thread slices. The (MThreadClusterSize,
// KThreadClusterSize) pair sweeps 128x2, 64x4, ..., 1x256, and for each pair the same six
// combinations of XDyDxVectorDim and vector sizes are listed.
// The order of the entries is the order of the std::tuple element types.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_bf16_blockwise_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Table of DeviceBatchNormBwdImpl configurations with UseMultiBlockInK = true.
// Apart from that flag, the entries carry the same data types and the same sweep of
// thread-cluster shapes and vector sizes as the UseMultiBlockInK = false table for BF16
// in this file.
// The order of the entries is the order of the std::tuple element types.
// NOTE(review): the column header below was aligned with the header of the blockwise
// table (DscaleDbiasDataType / ScaleSrcVectorSize / DscaleDbiasDstVectorSize), because
// both tables instantiate the same template - confirm against the parameter list of
// DeviceBatchNormBwdImpl.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_bf16_multiblock_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// Registers the BF16 batchnorm-backward kernel configurations for Rank = 4 and
// NumReduceDim = 3 with a PassThrough dy element-wise operator.
//
// `instances` is handed to add_device_operation_instances twice: first with the
// UseMultiBlockInK = false table, then with the UseMultiBlockInK = true table, both
// defined earlier in this file.
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>& instances)
{
    using BlockwiseConfigs  = device_batchnorm_backward_bf16_blockwise_instances<4, 3, PassThrough>;
    using MultiblockConfigs = device_batchnorm_backward_bf16_multiblock_instances<4, 3, PassThrough>;

    // Keep the blockwise table first so the registration order is unchanged.
    add_device_operation_instances(instances, BlockwiseConfigs{});
    add_device_operation_instances(instances, MultiblockConfigs{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short type aliases used by the instance tables below.
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Table of DeviceBatchNormBwdImpl configurations with UseMultiBlockInK = false.
// Data types (see the column header): F16 for X and Scale, F32 for everything else.
// Rank, NumReduceDim and DyElementwiseOp are forwarded from the alias parameters.
// Every entry uses BlockSize 256 and 2 x 2 thread slices. The (MThreadClusterSize,
// KThreadClusterSize) pair sweeps 128x2, 64x4, ..., 1x256, and for each pair the same six
// combinations of XDyDxVectorDim and vector sizes are listed.
// The order of the entries is the order of the std::tuple element types.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f16_blockwise_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Tuple of DeviceBatchNormBwdImpl instantiations for the F16 batchnorm-backward "multiblock"
// variant, parameterized on Rank, NumReduceDim and DyElementwiseOp.
// Every row passes `true` in the UseMultiBlockInK column, 256 in the BlockSize column and 2 for
// both thread-slice sizes. Rows differ only in the thread-cluster shape (128x2 down to 1x256,
// always multiplying to 256) and in the vector-dim / vector-size columns; the same six vector
// configurations are repeated for each cluster shape.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f16_multiblock_instances =
std::tuple <
// NOTE(review): the legend below is a label row only and is not checked by the compiler. Columns 6, 21 and 22 were
// previously labelled BiasDataType, ScaleSrcDstVectorSize and BiasDstVectorSize; they are renamed to the DscaleDbias*
// naming used by the sibling batchnorm-backward instance tables. Confirm against the DeviceBatchNormBwdImpl declaration.
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
// MThreadClusterSize x KThreadClusterSize = 128 x 2
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 64 x 4
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 32 x 8
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 16 x 16
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 8 x 32
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 4 x 64
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 2 x 128
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 1 x 256
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// Appends the F16 batchnorm-backward device instances for Rank = 4 / NumReduceDim = 3 with a
// PassThrough dy element-wise op to `instances`.
//
// instances: output vector of owning pointers to DeviceBatchNormBwd; it is passed on to
//            add_device_operation_instances twice, first with the blockwise tuple and then with
//            the multiblock tuple, so blockwise entries are handed over before multiblock ones.
void add_device_batchnorm_backward_rank_4_3_f16_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&
        instances)
{
    // Name the two instance tuples once so the registration calls stay short.
    using BlockwiseTuple =
        device_batchnorm_backward_f16_blockwise_instances<4, 3, PassThrough>;
    using MultiblockTuple =
        device_batchnorm_backward_f16_multiblock_instances<4, 3, PassThrough>;

    // Order matters for the resulting vector layout: blockwise first, multiblock second.
    add_device_operation_instances(instances, BlockwiseTuple{});
    add_device_operation_instances(instances, MultiblockTuple{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short spellings used by the instance tables and the registration function in this file.
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Tuple of DeviceBatchNormBwdImpl instantiations for the all-F32 batchnorm-backward "blockwise"
// variant, parameterized on Rank, NumReduceDim and DyElementwiseOp.
// Every row passes `false` in the UseMultiBlockInK column, 256 in the BlockSize column and 2 for
// both thread-slice sizes. Rows differ only in the thread-cluster shape (128x2 down to 1x256,
// always multiplying to 256) and in the vector-dim / vector-size columns; the same six vector
// configurations are repeated for each cluster shape.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f32_blockwise_instances = std::tuple<
// Column legend (label row only, not checked by the compiler):
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
// MThreadClusterSize x KThreadClusterSize = 128 x 2
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 64 x 4
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 32 x 8
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 16 x 16
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 8 x 32
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 4 x 64
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 2 x 128
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 1 x 256
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Tuple of DeviceBatchNormBwdImpl instantiations for the all-F32 batchnorm-backward "multiblock"
// variant, parameterized on Rank, NumReduceDim and DyElementwiseOp.
// Every row passes `true` in the UseMultiBlockInK column, 256 in the BlockSize column and 2 for
// both thread-slice sizes. Rows differ only in the thread-cluster shape (128x2 down to 1x256,
// always multiplying to 256) and in the vector-dim / vector-size columns; the same six vector
// configurations are repeated for each cluster shape.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f32_multiblock_instances =
std::tuple <
// NOTE(review): the legend below is a label row only and is not checked by the compiler. Columns 6, 21 and 22 were
// previously labelled BiasDataType, ScaleSrcDstVectorSize and BiasDstVectorSize; they are renamed to the DscaleDbias*
// naming used by the blockwise table's legend in this file. Confirm against the DeviceBatchNormBwdImpl declaration.
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
// MThreadClusterSize x KThreadClusterSize = 128 x 2
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 64 x 4
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 32 x 8
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 16 x 16
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 8 x 32
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 4 x 64
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 2 x 128
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 1 x 256
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// Appends the all-F32 batchnorm-backward device instances for Rank = 4 / NumReduceDim = 3 with a
// PassThrough dy element-wise op to `instances`.
//
// instances: output vector of owning pointers to DeviceBatchNormBwd; it is passed on to
//            add_device_operation_instances twice, first with the blockwise tuple and then with
//            the multiblock tuple, so blockwise entries are handed over before multiblock ones.
void add_device_batchnorm_backward_rank_4_3_f32_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&
        instances)
{
    // Name the two instance tuples once so the registration calls stay short.
    using BlockwiseTuple =
        device_batchnorm_backward_f32_blockwise_instances<4, 3, PassThrough>;
    using MultiblockTuple =
        device_batchnorm_backward_f32_multiblock_instances<4, 3, PassThrough>;

    // Order matters for the resulting vector layout: blockwise first, multiblock second.
    add_device_operation_instances(instances, BlockwiseTuple{});
    add_device_operation_instances(instances, MultiblockTuple{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short spellings used by the instance tables in this file.
using F64 = double;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
// Tuple of DeviceBatchNormBwdImpl instantiations for the all-F64 batchnorm-backward "blockwise"
// variant, parameterized on Rank, NumReduceDim and DyElementwiseOp.
// Every row passes `false` in the UseMultiBlockInK column, 256 in the BlockSize column and 2 for
// both thread-slice sizes. Rows differ only in the thread-cluster shape (128x2 down to 1x256,
// always multiplying to 256) and in the vector-dim / vector-size columns; the same six vector
// configurations are repeated for each cluster shape.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f64_blockwise_instances = std::tuple<
// Column legend (label row only, not checked by the compiler):
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
// MThreadClusterSize x KThreadClusterSize = 128 x 2
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 64 x 4
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 32 x 8
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 16 x 16
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 8 x 32
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 4 x 64
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 2 x 128
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 1 x 256
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// clang-format off
// Tuple of F64 DeviceBatchNormBwdImpl instantiations with UseMultiBlockInK = true,
// parameterized on the tensor rank, the number of reduced dimensions and the dy
// elementwise operation.
//
// Every entry uses BlockSize = 256 and MThreadSliceSize = KThreadSliceSize = 2. The list
// sweeps the MThreadClusterSize x KThreadClusterSize split (product always 256) from
// 128 x 2 down to 1 x 256; for each split the same six combinations of XDyDxVectorDim
// and the source/destination vector sizes are listed.
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
using device_batchnorm_backward_f64_multiblock_instances =
std::tuple <
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize
// MThreadClusterSize x KThreadClusterSize = 128 x 2
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 64 x 4
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 32 x 8
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 16 x 16
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 8 x 32
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 4 x 64
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 2 x 128
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
// MThreadClusterSize x KThreadClusterSize = 1 x 256
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
>;
// clang-format on
// Appends the F64 batchnorm-backward instances for Rank = 4 / NumReduceDim = 3 with a
// PassThrough dy elementwise op to `instances`: the blockwise tuple is registered first,
// followed by the multiblock tuple.
void add_device_batchnorm_backward_rank_4_3_f64_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&
        instances)
{
    // Shape served by this function; these must stay equal to the trailing
    // `4, 3` arguments of the DeviceBatchNormBwd type in the signature.
    constexpr int kRank         = 4;
    constexpr int kNumReduceDim = 3;

    add_device_operation_instances(
        instances,
        device_batchnorm_backward_f64_blockwise_instances<kRank, kNumReduceDim, PassThrough>{});

    add_device_operation_instances(
        instances,
        device_batchnorm_backward_f64_multiblock_instances<kRank, kNumReduceDim, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...@@ -22,6 +21,7 @@ using WeiDataType = ck::half_t; ...@@ -22,6 +21,7 @@ using WeiDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -44,45 +44,47 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial ...@@ -44,45 +44,47 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple< using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances = std::tuple< using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances = std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances = using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, Empty_Tuple,
InDataType, OutLayout,
WeiDataType, InDataType,
OutDataType, WeiDataType,
InElementOp, Empty_Tuple,
WeiElementOp, OutDataType,
OutElementOp>>>& instances) InElementOp,
WeiElementOp,
OutElementOp>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances{}); device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances{});
......
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...@@ -22,6 +21,7 @@ using WeiDataType = float; ...@@ -22,6 +21,7 @@ using WeiDataType = float;
using AccDataType = float; using AccDataType = float;
using OutDataType = float; using OutDataType = float;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -44,46 +44,51 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial ...@@ -44,46 +44,51 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances = std::tuple< using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances = std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // clang-format off
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances = std::tuple< using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances = std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // clang-format off
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances = using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // clang-format off
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, Empty_Tuple,
InDataType, OutLayout,
WeiDataType, InDataType,
OutDataType, WeiDataType,
InElementOp, Empty_Tuple,
WeiElementOp, OutDataType,
OutElementOp>>>& instances) InElementOp,
WeiElementOp,
OutElementOp>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances{}); device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances{});
......
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...@@ -22,6 +21,7 @@ using WeiDataType = int8_t; ...@@ -22,6 +21,7 @@ using WeiDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using OutDataType = int8_t; using OutDataType = int8_t;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -44,46 +44,48 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial ...@@ -44,46 +44,48 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple< using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances = std::tuple< using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances = std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances = using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
// ###############################| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ###############################| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ###############################| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, Empty_Tuple,
InDataType, OutLayout,
WeiDataType, InDataType,
OutDataType, WeiDataType,
InElementOp, Empty_Tuple,
WeiElementOp, OutDataType,
OutElementOp>>>& instances) InElementOp,
WeiElementOp,
OutElementOp>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances{}); device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances{});
......
add_instance_library(device_quantization_instance add_instance_library(device_quantization_instance
device_conv2d_xdl_bias_quant_int8_instance.cpp device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp
device_conv2d_xdl_quant_int8_instance.cpp device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp
device_conv2d_xdl_perchannel_quantization_int8_instance.cpp
device_conv2d_xdl_perlayer_quantization_int8_instance.cpp
) )
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_conv2d_xdl_int8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers the int8 conv2d instances whose output element-wise op is
// Add_Mul2_Clamp (two D tensors: GK_GK_Tuple layouts, I32_F32_Tuple data types).
//
// instances: caller-owned list that receives the registered operations.
//
// One instance tuple is registered per convolution-forward specialization
// constant, in the order ConvFwdDefault, ConvFwd1x1P0, ConvFwd1x1S1P0.
// NOTE(review): add_device_operation_instances and the instance tuple template
// come from the included headers; presumably each tuple element is appended to
// `instances` - confirm against add_device_operation_instance.hpp.
void add_device_conv2d_bias_perchannel_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              GK_GK_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              I32_F32_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Add_Mul2_Clamp>>>& instances)
{
    // Name each instance tuple once so the registration calls stay short.
    using DefaultInstances = device_conv2d_int8_32Ds_instances<GK_GK_Tuple,
                                                               I32_F32_Tuple,
                                                               Add_Mul2_Clamp,
                                                               ConvFwdDefault>;
    using Conv1x1P0Instances = device_conv2d_int8_32Ds_instances<GK_GK_Tuple,
                                                                 I32_F32_Tuple,
                                                                 Add_Mul2_Clamp,
                                                                 ConvFwd1x1P0>;
    using Conv1x1S1P0Instances = device_conv2d_int8_32Ds_instances<GK_GK_Tuple,
                                                                   I32_F32_Tuple,
                                                                   Add_Mul2_Clamp,
                                                                   ConvFwd1x1S1P0>;

    // Registration order is kept identical to the original sequence.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Conv1x1P0Instances{});
    add_device_operation_instances(instances, Conv1x1S1P0Instances{});
}
// Registers the int8 conv2d instances whose output element-wise op is
// Add_Relu_Mul2_Clamp (two D tensors: GK_GK_Tuple layouts, I32_F32_Tuple data
// types).
//
// instances: caller-owned list that receives the registered operations.
//
// One instance tuple is registered per convolution-forward specialization
// constant, in the order ConvFwdDefault, ConvFwd1x1P0, ConvFwd1x1S1P0.
// NOTE(review): add_device_operation_instances and the instance tuple template
// come from the included headers; presumably each tuple element is appended to
// `instances` - confirm against add_device_operation_instance.hpp.
void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              GK_GK_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              I32_F32_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Add_Relu_Mul2_Clamp>>>& instances)
{
    // Name each instance tuple once so the registration calls stay short.
    using DefaultInstances = device_conv2d_int8_32Ds_instances<GK_GK_Tuple,
                                                               I32_F32_Tuple,
                                                               Add_Relu_Mul2_Clamp,
                                                               ConvFwdDefault>;
    using Conv1x1P0Instances = device_conv2d_int8_32Ds_instances<GK_GK_Tuple,
                                                                 I32_F32_Tuple,
                                                                 Add_Relu_Mul2_Clamp,
                                                                 ConvFwd1x1P0>;
    using Conv1x1S1P0Instances = device_conv2d_int8_32Ds_instances<GK_GK_Tuple,
                                                                   I32_F32_Tuple,
                                                                   Add_Relu_Mul2_Clamp,
                                                                   ConvFwd1x1S1P0>;

    // Registration order is kept identical to the original sequence.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Conv1x1P0Instances{});
    add_device_operation_instances(instances, Conv1x1S1P0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_conv2d_xdl_int8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers the int8 conv2d instances whose output element-wise op is
// Add_Mul_Clamp (one D tensor: GK_Tuple layout, I32_Tuple data type).
//
// instances: caller-owned list that receives the registered operations.
//
// One instance tuple is registered per convolution-forward specialization
// constant, in the order ConvFwdDefault, ConvFwd1x1P0, ConvFwd1x1S1P0.
// NOTE(review): add_device_operation_instances and the instance tuple template
// come from the included headers; presumably each tuple element is appended to
// `instances` - confirm against add_device_operation_instance.hpp.
void add_device_conv2d_bias_perlayer_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              GK_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              I32_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Add_Mul_Clamp>>>& instances)
{
    // Name each instance tuple once so the registration calls stay short.
    using DefaultInstances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, I32_Tuple, Add_Mul_Clamp, ConvFwdDefault>;
    using Conv1x1P0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, I32_Tuple, Add_Mul_Clamp, ConvFwd1x1P0>;
    using Conv1x1S1P0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, I32_Tuple, Add_Mul_Clamp, ConvFwd1x1S1P0>;

    // Registration order is kept identical to the original sequence.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Conv1x1P0Instances{});
    add_device_operation_instances(instances, Conv1x1S1P0Instances{});
}
// Registers the int8 conv2d instances whose output element-wise op is
// Add_Relu_Mul_Clamp (one D tensor: GK_Tuple layout, I32_Tuple data type).
//
// instances: caller-owned list that receives the registered operations.
//
// One instance tuple is registered per convolution-forward specialization
// constant, in the order ConvFwdDefault, ConvFwd1x1P0, ConvFwd1x1S1P0.
// NOTE(review): add_device_operation_instances and the instance tuple template
// come from the included headers; presumably each tuple element is appended to
// `instances` - confirm against add_device_operation_instance.hpp.
void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              GK_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              I32_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Add_Relu_Mul_Clamp>>>& instances)
{
    // Name each instance tuple once so the registration calls stay short.
    using DefaultInstances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, I32_Tuple, Add_Relu_Mul_Clamp, ConvFwdDefault>;
    using Conv1x1P0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, I32_Tuple, Add_Relu_Mul_Clamp, ConvFwd1x1P0>;
    using Conv1x1S1P0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, I32_Tuple, Add_Relu_Mul_Clamp, ConvFwd1x1S1P0>;

    // Registration order is kept identical to the original sequence.
    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Conv1x1P0Instances{});
    add_device_operation_instances(instances, Conv1x1S1P0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Shorthand for ck::Sequence, a compile-time list of ck::index_t values; the
// instance table below uses it for its S<...> template arguments.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Layout tags from ck::tensor_layout::convolution. In the instance table they
// fill the A (GNHWC), B (GKYXC), E (GNHWK) and, via GK_Tuple, Ds layout columns.
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GK = ck::tensor_layout::convolution::G_K;
// Element-wise functors used as template arguments below.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
// Single-element tuples describing the one extra D tensor: its layout (GK) and
// its data type (int32_t).
using GK_Tuple = ck::Tuple<GK>;
using I32_Tuple = ck::Tuple<int32_t>;
// Output (CDE) element-wise ops: Add_Activation_Mul_Clamp parameterized with
// PassThrough or Relu as the activation.
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
// Number of spatial dimensions used in the registration function signatures.
static constexpr ck::index_t NDimSpatial = 2;
// GEMM specialization passed to every instance in the table (MNKPadding).
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Short names for the three ConvolutionForwardSpecialization values that the
// registration functions instantiate the table with.
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
// TODO - Add more instances
// Tuple of 13 DeviceGroupedConvFwdMultipleD_Xdl_CShuffle configurations.
//
// Template parameters:
//   OutElementOp - element-wise op placed in the "CDE Elementwise Operation"
//                  column of every row.
//   ConvSpec     - ConvolutionForwardSpecialization placed in the "ConvForward
//                  Specialization" column of every row.
//
// Every row shares the layouts (GNHWC / GKYXC / GK_Tuple / GNHWK), the data
// types (int8_t A/B/E, int32_t accumulator and CShuffle, I32_Tuple Ds),
// GemmSpec, KPerBlock = 64, AK1 = BK1 = 16, 32x32 XDL tiles and a
// CBlockTransfer ScalarPerVector of 8. Rows differ in block size, M/N per
// block, M/N XDL per wave and the thread-cluster lengths.
// The table is kept under clang-format off so the columns stay aligned with
// the header rows.
template <typename OutElementOp, ConvolutionForwardSpecialization ConvSpec>
// clang-format off
using device_conv2d_int8_instances =
std::tuple <
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>
>;
// clang-format on
void add_device_conv2d_bias_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
ck::Tuple<GK>,
GNHWK,
int8_t,
int8_t,
ck::Tuple<int32_t>,
int8_t,
PassThrough,
PassThrough,
Add_Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_conv2d_int8_instances<Add_Mul_Clamp, ConvFwdDefault>{});
add_device_operation_instances(instances,
device_conv2d_int8_instances<Add_Mul_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_conv2d_int8_instances<Add_Mul_Clamp, ConvFwd1x1S1P0>{});
}
void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
GKYXC,
ck::Tuple<GK>,
GNHWK,
int8_t,
int8_t,
ck::Tuple<int32_t>,
int8_t,
PassThrough,
PassThrough,
Add_Relu_Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances, device_conv2d_int8_instances<Add_Relu_Mul_Clamp, ConvFwdDefault>{});
add_device_operation_instances(
instances, device_conv2d_int8_instances<Add_Relu_Mul_Clamp, ConvFwd1x1P0>{});
add_device_operation_instances(
instances, device_conv2d_int8_instances<Add_Relu_Mul_Clamp, ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// ---------------------------------------------------------------------------
// Shorthand aliases and constants shared by the int8 conv2d instance tables
// and the add_device_conv2d_*_int8_instances registration functions.
// ---------------------------------------------------------------------------

// Tuple with no elements; passed as both the Ds layout and the Ds data type
// by the registration functions that use Mul_Clamp / Relu_Mul_Clamp.
using Empty_Tuple = ck::Tuple<>;
// Compile-time integer sequence shorthand used in the instance tables.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Tensor layout tags.
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GK = ck::tensor_layout::convolution::G_K;
// Element-wise operations used directly or as the activation parameter below.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
// Layout tuples for one or two D tensors, each with the G_K layout.
using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
// Data-type tuples for the D tensors.
using I32_Tuple = ck::Tuple<int32_t>;
using F32_Tuple = ck::Tuple<float>;
using I32_F32_Tuple = ck::Tuple<int32_t, float>;
// Output element-wise operations. Each name mirrors the underlying template:
// an optional "Add_" prefix selects the Add_Activation_* variant, "Relu_"
// means the activation parameter is Relu (otherwise PassThrough), and
// "Mul" / "Mul2" selects the *_Mul_Clamp / *_Mul2_Clamp variant.
using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Relu>;
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<PassThrough>;
using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Relu>;
using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<PassThrough>;
using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Relu>;
// Number of spatial dimensions used in the registration function signatures.
static constexpr ck::index_t NDimSpatial = 2;
// GEMM specialization passed to every entry of the instance tables.
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Convolution-forward specializations; each registration function registers
// one instance table per specialization.
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
// Table of 13 DeviceGroupedConvFwdMultipleD_Xdl_CShuffle instantiations for
// 2D int8 convolution (int8 A/B/E data, int32 accumulator and CShuffle types,
// GNHWC / GKYXC / GNHWK layouts).
//
// Template parameters:
//   DsLayout     - layout tuple of the D tensors.
//   DsDatatype   - data-type tuple of the D tensors.
//   OutElementOp - element-wise operation used in the "CDE" slot.
//   ConvSpec     - convolution-forward specialization.
//
// The rows differ in block size, per-block tile sizes, XDL-per-wave counts and
// the matching thread-cluster shapes. Every row uses 16 for the final argument
// (the "CBlockTransfer ScalarPerVector _NWaveNPerXdl" column).
template <typename DsLayout,
typename DsDatatype,
typename OutElementOp,
ConvolutionForwardSpecialization ConvSpec>
// clang-format off
using device_conv2d_int8_instances =
std::tuple <
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>
>;
// clang-format on
// Instance table for conv2d fused with one or more 32-bit D tensors.
// The bit width of the Ds affects the ScalarPerVector of C, so this table
// exists alongside device_conv2d_int8_instances: its 13 rows carry the same
// tile configurations, but the final argument (the "CBlockTransfer
// ScalarPerVector _NWaveNPerXdl" column) is 8 instead of 16.
//
// Template parameters:
//   DsLayout     - layout tuple of the D tensors.
//   DsDatatype   - data-type tuple of the D tensors.
//   OutElementOp - element-wise operation used in the "CDE" slot.
//   ConvSpec     - convolution-forward specialization.
template <typename DsLayout,
typename DsDatatype,
typename OutElementOp,
ConvolutionForwardSpecialization ConvSpec>
// clang-format off
using device_conv2d_int8_32Ds_instances =
std::tuple <
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>
>;
// clang-format on
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_conv2d_xdl_int8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers the int8 conv2d instances that use Mul2_Clamp as the output
// element-wise operation, with one float D tensor in G_K layout.
//
// For each of the three convolution-forward specializations (ConvFwdDefault,
// ConvFwd1x1P0, ConvFwd1x1S1P0, in that order) a default-constructed
// device_conv2d_int8_32Ds_instances tuple is handed to
// add_device_operation_instances together with `instances`.
void add_device_conv2d_perchannel_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              GK_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              F32_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Mul2_Clamp>>>& instances)
{
    // One instance tuple per convolution-forward specialization.
    using DefaultInstances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Mul2_Clamp, ConvFwdDefault>;
    using Filter1x1Pad0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Mul2_Clamp, ConvFwd1x1P0>;
    using Filter1x1Stride1Pad0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Mul2_Clamp, ConvFwd1x1S1P0>;

    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Pad0Instances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
// Registers the int8 conv2d instances that use Relu_Mul2_Clamp as the output
// element-wise operation, with one float D tensor in G_K layout.
//
// For each of the three convolution-forward specializations (ConvFwdDefault,
// ConvFwd1x1P0, ConvFwd1x1S1P0, in that order) a default-constructed
// device_conv2d_int8_32Ds_instances tuple is handed to
// add_device_operation_instances together with `instances`.
void add_device_conv2d_relu_perchannel_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              GK_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              F32_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Relu_Mul2_Clamp>>>& instances)
{
    // One instance tuple per convolution-forward specialization.
    using DefaultInstances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Relu_Mul2_Clamp, ConvFwdDefault>;
    using Filter1x1Pad0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Relu_Mul2_Clamp, ConvFwd1x1P0>;
    using Filter1x1Stride1Pad0Instances =
        device_conv2d_int8_32Ds_instances<GK_Tuple, F32_Tuple, Relu_Mul2_Clamp, ConvFwd1x1S1P0>;

    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Pad0Instances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "device_conv2d_xdl_int8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers the int8 conv2d instances that use Mul_Clamp as the output
// element-wise operation and have no D tensors (Empty_Tuple for both the Ds
// layout and the Ds data type).
//
// For each of the three convolution-forward specializations (ConvFwdDefault,
// ConvFwd1x1P0, ConvFwd1x1S1P0, in that order) a default-constructed
// device_conv2d_int8_instances tuple is handed to
// add_device_operation_instances together with `instances`.
void add_device_conv2d_perlayer_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              Empty_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              Empty_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Mul_Clamp>>>& instances)
{
    // One instance tuple per convolution-forward specialization.
    using DefaultInstances =
        device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Mul_Clamp, ConvFwdDefault>;
    using Filter1x1Pad0Instances =
        device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Mul_Clamp, ConvFwd1x1P0>;
    using Filter1x1Stride1Pad0Instances =
        device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Mul_Clamp, ConvFwd1x1S1P0>;

    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Pad0Instances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
// Registers the int8 conv2d instances that use Relu_Mul_Clamp as the output
// element-wise operation and have no D tensors (Empty_Tuple for both the Ds
// layout and the Ds data type).
//
// For each of the three convolution-forward specializations (ConvFwdDefault,
// ConvFwd1x1P0, ConvFwd1x1S1P0, in that order) a default-constructed
// device_conv2d_int8_instances tuple is handed to
// add_device_operation_instances together with `instances`.
void add_device_conv2d_relu_perlayer_quantization_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                                              GNHWC,
                                                              GKYXC,
                                                              Empty_Tuple,
                                                              GNHWK,
                                                              int8_t,
                                                              int8_t,
                                                              Empty_Tuple,
                                                              int8_t,
                                                              PassThrough,
                                                              PassThrough,
                                                              Relu_Mul_Clamp>>>& instances)
{
    // One instance tuple per convolution-forward specialization.
    using DefaultInstances =
        device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Relu_Mul_Clamp, ConvFwdDefault>;
    using Filter1x1Pad0Instances =
        device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Relu_Mul_Clamp, ConvFwd1x1P0>;
    using Filter1x1Stride1Pad0Instances =
        device_conv2d_int8_instances<Empty_Tuple, Empty_Tuple, Relu_Mul_Clamp, ConvFwd1x1S1P0>;

    add_device_operation_instances(instances, DefaultInstances{});
    add_device_operation_instances(instances, Filter1x1Pad0Instances{});
    add_device_operation_instances(instances, Filter1x1Stride1Pad0Instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment