Commit 898e1d22 authored by rocking's avatar rocking
Browse files

Revise layout of external api of group conv quantization

parent a92f4ea8
...@@ -17,14 +17,14 @@ namespace tensor_operation { ...@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( ...@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -52,10 +52,10 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances( ...@@ -52,10 +52,10 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -68,10 +68,10 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( ...@@ -68,10 +68,10 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -83,10 +83,10 @@ void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( ...@@ -83,10 +83,10 @@ void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -99,10 +99,10 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances( ...@@ -99,10 +99,10 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_GK_Tuple, GK_GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_F32_Tuple, I32_F32_Tuple,
...@@ -154,9 +154,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -154,9 +154,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>) is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
...@@ -220,9 +220,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -220,9 +220,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>) is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
......
...@@ -17,14 +17,14 @@ namespace tensor_operation { ...@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( ...@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -51,10 +51,10 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances( ...@@ -51,10 +51,10 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -67,10 +67,10 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( ...@@ -67,10 +67,10 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector< std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -82,10 +82,10 @@ void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( ...@@ -82,10 +82,10 @@ void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -97,10 +97,10 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances( ...@@ -97,10 +97,10 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
I32_Tuple, I32_Tuple,
...@@ -152,9 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -152,9 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>) is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
...@@ -218,9 +218,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -218,9 +218,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>) is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
......
...@@ -17,13 +17,13 @@ namespace tensor_operation { ...@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perchannel_quantization_int8_instances( void add_device_conv2d_dl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
...@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perchannel_quantization_int8_instances( ...@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perchannel_quantization_int8_instances(
void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances( void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
...@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances( ...@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
...@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perchannel_quantization_int8_instances( ...@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances( void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
GK_Tuple, GK_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
F32_Tuple, F32_Tuple,
...@@ -119,9 +119,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -119,9 +119,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
......
...@@ -17,13 +17,13 @@ namespace tensor_operation { ...@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perlayer_quantization_int8_instances( void add_device_conv2d_dl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perlayer_quantization_int8_instances( ...@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perlayer_quantization_int8_instances(
void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances( void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances( ...@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances( ...@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -117,8 +117,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -117,8 +117,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>) is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment