Unverified Commit f2398f61 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Introduce multiABD api and deprecate multiD (#1035)

* Introduce multiABD api and deprecate multiD

* Replace multiD with multiABD

* Mark structures as deprecated

* Change doxygen deprecated to note to avoid warnings
parent 5356c4a9
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -20,96 +20,96 @@ namespace instance { ...@@ -20,96 +20,96 @@ namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<PassThrough>>>>& Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<Relu>>>>& Add_Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>& Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<PassThrough>>>>& Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul2_Clamp<Relu>>>>& Add_Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>& Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
// piecewise activation function // piecewise activation function
...@@ -123,7 +123,7 @@ template <ck::index_t NumDimSpatial, ...@@ -123,7 +123,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -137,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -137,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>> Add_Activation_Mul2_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -193,7 +194,7 @@ template <ck::index_t NumDimSpatial, ...@@ -193,7 +194,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -207,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -207,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>> Add_Mul2_Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -20,94 +20,96 @@ namespace instance { ...@@ -20,94 +20,96 @@ namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<PassThrough>>>>& Add_Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<Relu>>>>& Add_Activation_Mul_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
I32_Tuple, int8_t,
int8_t, I32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>& PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<PassThrough>>>>& Add_Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Add_Activation_Mul_Clamp<Relu>>>>& Add_Activation_Mul_Clamp<Relu>>>>&
instances); instances);
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
I32_Tuple, int8_t,
int8_t, I32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>& PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances); instances);
// piecewise activation function // piecewise activation function
...@@ -121,7 +123,7 @@ template <ck::index_t NumDimSpatial, ...@@ -121,7 +123,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -135,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -135,18 +137,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>> Add_Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -191,7 +194,7 @@ template <ck::index_t NumDimSpatial, ...@@ -191,7 +194,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -205,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -205,18 +208,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>> Add_Mul_Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
DsLayout, WeiLayout,
OutLayout, DsLayout,
InDataType, OutLayout,
WeiDataType, InDataType,
DsDataType, WeiDataType,
OutDataType, DsDataType,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -19,63 +19,65 @@ namespace instance { ...@@ -19,63 +19,65 @@ namespace instance {
#ifdef DL_KERNELS #ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perchannel_quantization_int8_instances( void add_device_conv2d_dl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
F32_Tuple, int8_t,
int8_t, F32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul2_Clamp<Relu>>>>& Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
GK_Tuple, GKYXC,
NHWGK, GK_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
F32_Tuple, int8_t,
int8_t, F32_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul2_Clamp<Relu>>>>& Activation_Mul2_Clamp<Relu>>>>&
instances); instances);
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
...@@ -88,7 +90,7 @@ template <ck::index_t NumDimSpatial, ...@@ -88,7 +90,7 @@ template <ck::index_t NumDimSpatial,
typename DsDataType, typename DsDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -102,18 +104,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -102,18 +104,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>> Activation_Mul2_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
GK_Tuple, WeiLayout,
OutLayout, GK_Tuple,
InDataType, OutLayout,
WeiDataType, InDataType,
F32_Tuple, WeiDataType,
OutDataType, F32_Tuple,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -19,63 +19,65 @@ namespace instance { ...@@ -19,63 +19,65 @@ namespace instance {
#ifdef DL_KERNELS #ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perlayer_quantization_int8_instances( void add_device_conv2d_dl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
Empty_Tuple, GKYXC,
NHWGK, Empty_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
Empty_Tuple, int8_t,
int8_t, Empty_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul_Clamp<Relu>>>>& Activation_Mul_Clamp<Relu>>>>&
instances); instances);
#endif #endif
void add_device_conv2d_xdl_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<
NHWGC, std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GKYXC, NHWGC,
Empty_Tuple, GKYXC,
NHWGK, Empty_Tuple,
int8_t, NHWGK,
int8_t, int8_t,
Empty_Tuple, int8_t,
int8_t, Empty_Tuple,
PassThrough, int8_t,
PassThrough, PassThrough,
Activation_Mul_Clamp<PassThrough>>>>& PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances); instances);
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
NHWGK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
Activation_Mul_Clamp<Relu>>>>& Activation_Mul_Clamp<Relu>>>>&
instances); instances);
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
...@@ -86,7 +88,7 @@ template <ck::index_t NumDimSpatial, ...@@ -86,7 +88,7 @@ template <ck::index_t NumDimSpatial,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename Activation> typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -100,18 +102,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -100,18 +102,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>> Activation_Mul_Clamp<Activation>>>
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial, using DeviceOp =
InLayout, DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
WeiLayout, InLayout,
Empty_Tuple, WeiLayout,
OutLayout, Empty_Tuple,
InDataType, OutLayout,
WeiDataType, InDataType,
Empty_Tuple, WeiDataType,
OutDataType, Empty_Tuple,
ck::tensor_operation::element_wise::PassThrough, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>; ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC, GNWC,
GKXC, GKXC,
Empty_Tuple, Empty_Tuple,
GNWK, GNWK,
BF16, BF16,
BF16, BF16,
Empty_Tuple, Empty_Tuple,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<1, device_grouped_conv_fwd_xdl_bf16_instances<1,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC, GNWC,
GKXC, GKXC,
Empty_Tuple, Empty_Tuple,
GNWK, GNWK,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f16_instances<1, device_grouped_conv_fwd_xdl_f16_instances<1,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC, GNWC,
GKXC, GKXC,
Empty_Tuple, Empty_Tuple,
GNWK, GNWK,
F32, F32,
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_instances<1, device_grouped_conv_fwd_xdl_f32_instances<1,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
GNWC, GNWC,
GKXC, GKXC,
Empty_Tuple, Empty_Tuple,
GNWK, GNWK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_int8_instances<1, device_grouped_conv_fwd_xdl_int8_instances<1,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_dl_f16_instances<GNHWC, device_grouped_conv2d_fwd_dl_f16_instances<GNHWC,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
F32, F32,
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_dl_f32_instances<GNHWC, device_grouped_conv2d_fwd_dl_f32_instances<GNHWC,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
NHWGK, NHWGK,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_dl_f16_instances<NHWGC, device_grouped_conv2d_fwd_dl_f16_instances<NHWGC,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
NHWGK, NHWGK,
F32, F32,
F32, F32,
Empty_Tuple, Empty_Tuple,
F32, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_dl_f32_instances<NHWGC, device_grouped_conv2d_fwd_dl_f32_instances<NHWGC,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2, device_grouped_conv_fwd_wmma_f16_instances<2,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2, device_grouped_conv_fwd_wmma_f16_instances<2,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2, device_grouped_conv_fwd_wmma_f16_instances<2,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2, device_grouped_conv_fwd_wmma_f16_instances<2,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2, device_grouped_conv_fwd_wmma_i8_instances<2,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2, device_grouped_conv_fwd_wmma_i8_instances<2,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2, device_grouped_conv_fwd_wmma_i8_instances<2,
......
...@@ -10,18 +10,18 @@ namespace device { ...@@ -10,18 +10,18 @@ namespace device {
namespace instance { namespace instance {
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
int8_t, int8_t,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2, device_grouped_conv_fwd_wmma_i8_instances<2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment