Commit dddc2115 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents 6e01019b 08eb1769
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef __fp16__
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
@@ -58,7 +58,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
+#endif
+#ifdef __bf16__
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
@@ -100,6 +101,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
+#endif
template <typename ADataType,
typename B0DataType,
@@ -146,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef __fp16__
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
@@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
+#endif
+#ifdef __bf16__
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{
@@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
+#endif
return op_ptrs;
}
};
...
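For orientation, here is a minimal, self-contained sketch of the pattern these guards apply to the instance factories: each data-type branch of GetInstances() sits behind a preprocessor macro, so a build that does not define the macro simply returns no instances for that type. The names below (ToyOp, ToyInstanceFactory) and the use of __fp32__/__fp64__ as build switches are illustrative assumptions, not CK's actual API.

// Toy model of the guarded factory pattern (illustrative only, C++17).
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

struct ToyOp
{
    std::string name;
};

template <typename DataType>
struct ToyInstanceFactory
{
    static std::vector<std::unique_ptr<ToyOp>> GetInstances()
    {
        std::vector<std::unique_ptr<ToyOp>> op_ptrs;
#ifdef __fp32__
        if constexpr(std::is_same_v<DataType, float>)
            op_ptrs.push_back(std::make_unique<ToyOp>(ToyOp{"fp32 instance"}));
#endif
#ifdef __fp64__
        if constexpr(std::is_same_v<DataType, double>)
            op_ptrs.push_back(std::make_unique<ToyOp>(ToyOp{"fp64 instance"}));
#endif
        return op_ptrs; // empty when the matching macro was not defined at build time
    }
};

int main()
{
    for(const auto& op : ToyInstanceFactory<float>::GetInstances())
        std::cout << op->name << '\n'; // prints nothing unless -D__fp32__ was passed
}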
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef __fp32__
// float
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
@@ -65,7 +65,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
PassThrough,
PassThrough,
Bilinear>>>& instances);
+#endif
+#ifdef __fp64__
// double
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
@@ -114,7 +115,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
PassThrough,
PassThrough,
Bilinear>>>& instances);
+#endif
// Contraction + Bilinear
template <index_t NumDimM,
index_t NumDimN,
@@ -149,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef __fp32__
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<DDataType, float> && is_same_v<EDataType, float>)
{
@@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
+#endif
+#ifdef __fp64__
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<DDataType, double> && is_same_v<EDataType, double>)
{
@@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
+#endif
return op_ptrs;
}
};
...
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef __fp32__
// float
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
@@ -65,7 +65,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
PassThrough,
PassThrough,
Scale>>>& instances);
+#endif
+#ifdef __fp64__
// double
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
@@ -114,7 +115,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
PassThrough,
PassThrough,
Scale>>>& instances);
+#endif
// Contraction + Scale
template <index_t NumDimM,
index_t NumDimN,
@@ -148,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef __fp32__
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<EDataType, float>)
{
@@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
+#endif
+#ifdef __fp64__
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<EDataType, double>)
{
@@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
+#endif
return op_ptrs;
}
};
...
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef __bf16__
// conv1d backward data
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
@@ -29,16 +29,19 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
+#ifdef __fp16__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
+#endif
+#ifdef __fp32__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
+#endif
#ifdef __int8__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
@@ -52,6 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
#endif
+#ifdef __bf16__
// conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
@@ -64,7 +68,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
+#ifdef __fp16__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
@@ -76,7 +81,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
+#ifdef __fp32__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
@@ -88,6 +94,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
#ifdef __int8__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
@@ -101,6 +108,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
#endif
+#ifdef DL_KERNELS
+#ifdef __fp16__
// conv2d dl
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
@@ -113,7 +122,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
+#ifdef __fp32__
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
@@ -125,6 +135,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
#ifdef __int8__
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
@@ -138,6 +149,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
#endif
+#endif
+#ifdef __bf16__
// conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
@@ -150,7 +163,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
+#ifdef __fp16__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
@@ -162,7 +176,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
+#ifdef __fp32__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
@@ -174,6 +189,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
#ifdef __int8__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
@@ -229,20 +245,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
-else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
-is_same_v<OutDataType, half_t>)
+#ifdef __fp16__
+if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
-else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
-is_same_v<WeiDataType, ck::bhalf_t> &&
-is_same_v<OutDataType, ck::bhalf_t>)
+#endif
+#ifdef __bf16__
+if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
+is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
}
+#endif
#ifdef __int8__
-else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
-is_same_v<OutDataType, int8_t>)
+if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
+is_same_v<OutDataType, int8_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
}
@@ -255,26 +274,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
+#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
+#endif
}
-else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
-is_same_v<OutDataType, half_t>)
+#ifdef __fp16__
+if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
+#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
+#endif
}
-else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
-is_same_v<WeiDataType, ck::bhalf_t> &&
-is_same_v<OutDataType, ck::bhalf_t>)
+#endif
+#ifdef __bf16__
+if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
+is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
+#endif
#ifdef __int8__
-else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
-is_same_v<OutDataType, int8_t>)
+if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
+is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
+#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
+#endif
}
#endif
}
@@ -286,20 +314,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
-else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
-is_same_v<OutDataType, half_t>)
+#ifdef __fp16__
+if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
-else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
-is_same_v<WeiDataType, ck::bhalf_t> &&
-is_same_v<OutDataType, ck::bhalf_t>)
+#endif
+#ifdef __bf16__
+if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
+is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
}
+#endif
#ifdef __int8__
-else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
-is_same_v<OutDataType, int8_t>)
+if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
+is_same_v<OutDataType, int8_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
}
...
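One detail worth noting about the guarded factory branches above: where every branch can be compiled out independently, an else if constexpr chain is replaced by standalone if constexpr blocks, because a surviving "else" must not depend on an earlier branch that the preprocessor may have removed. A minimal, self-contained illustration of the pattern (the macro names are taken from this diff; the types and function are made up, with "short"/"unsigned short" standing in for ck::half_t/ck::bhalf_t):

#include <type_traits>

template <typename T>
int count_enabled_branches()
{
    int n = 0;
#ifdef __fp16__
    if constexpr(std::is_same_v<T, short>) // stand-in for ck::half_t
        ++n;
#endif
#ifdef __bf16__
    if constexpr(std::is_same_v<T, unsigned short>) // stand-in for ck::bhalf_t
        ++n;
#endif
    // Writing "else if constexpr" for the second branch would fail to compile
    // whenever __fp16__ is undefined but __bf16__ is defined, because the "else"
    // would be left without a preceding "if".
    return n;
}

int main() { return count_enabled_branches<short>(); }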
@@ -18,11 +18,17 @@ namespace device {
namespace instance {
// conv2d forward
+#ifdef __fp16__
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
+void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
+std::vector<std::unique_ptr<
+DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+instances);
+#endif
+#ifdef __bf16__
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
@@ -34,17 +40,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
-void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
-std::vector<std::unique_ptr<
-DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
-instances);
+#ifdef __fp32__
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
+#endif
+#ifdef __int8__
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
@@ -56,6 +59,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
+#endif
template <ck::index_t NumDimSpatial,
typename InLayout,
@@ -99,23 +103,29 @@ struct DeviceOperationInstanceFactory<
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
+#ifdef __fp16__
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
+#endif
+#ifdef __bf16__
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
+#endif
+#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
+#endif
}
return op_ptrs;
...
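Whether any of these guarded declarations and branches exist in a given build depends entirely on which macros the build system passes to the compiler. As a quick check, a hedged, stand-alone translation unit like the one below reports which switches were enabled; the -D spellings are an assumption based on the macro names appearing in this diff, not a documented CK build flag.

// Compile e.g. with: hipcc -D__fp16__ -D__fp32__ macro_check.cpp   (flag names assumed)
#include <cstdio>

int main()
{
#ifdef __fp16__
    std::puts("__fp16__ defined: fp16 instances compiled in");
#endif
#ifdef __bf16__
    std::puts("__bf16__ defined: bf16 instances compiled in");
#endif
#ifdef __fp32__
    std::puts("__fp32__ defined: fp32 instances compiled in");
#endif
#ifdef __int8__
    std::puts("__int8__ defined: int8 instances compiled in");
#endif
#ifdef DL_KERNELS
    std::puts("DL_KERNELS defined: DL kernel instances compiled in");
#endif
    return 0;
}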
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __fp16__
namespace ck {
namespace tensor_operation {
namespace device {
@@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -343,6 +343,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
+#ifdef __fp16__
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
@@ -388,6 +389,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
+#endif
+#ifdef __bf16__
else if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<CDataType, ck::bhalf_t>)
{
@@ -412,6 +415,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
}
}
+#endif
#ifdef __int8__
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
...
@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __fp16__
namespace ck {
namespace tensor_operation {
namespace device {
@@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __fp16__
namespace ck {
namespace tensor_operation {
namespace device {
@@ -144,3 +144,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -10,7 +10,7 @@
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
+#ifdef __fp16__
namespace ck {
namespace tensor_operation {
namespace device {
@@ -136,3 +136,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef __fp16__
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __fp16__
namespace ck {
namespace tensor_operation {
namespace device {
@@ -192,3 +192,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef __fp16__
// FP16
void add_device_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
@@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances(
void add_device_normalization_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
+#endif
+#ifdef __fp32__
// FP32
void add_device_normalization_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
@@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances(
void add_device_normalization_rank_5_3_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
+#endif
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
@@ -65,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef __fp16__
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
{
@@ -82,8 +83,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f16_instances(op_ptrs);
}
}
-else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
-is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
+#endif
+#ifdef __fp32__
+if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
+is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
@@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f32_instances(op_ptrs);
}
}
+#endif
return op_ptrs;
}
};
...
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
+#ifdef __fp16__
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
@@ -36,7 +36,8 @@ void add_device_pool2d_fwd_nhwc_f16_instances(
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
+#endif
+#ifdef __fp32__
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
@@ -50,7 +51,7 @@ void add_device_pool2d_fwd_nhwc_f32_instances(
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
+#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef __fp16__
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
@@ -88,8 +89,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
-else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
-is_same_v<IndexDataType, I32>)
+#endif
+#ifdef __fp32__
+if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
+is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
+#endif
return op_ptrs;
}
};
...
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
+#ifdef __fp16__
// FP16
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
@@ -36,7 +36,8 @@ void add_device_pool3d_fwd_ndhwc_f16_instances(
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
+#endif
+#ifdef __fp32__
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
@@ -50,7 +51,7 @@ void add_device_pool3d_fwd_ndhwc_f32_instances(
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
+#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef __fp16__
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
@@ -88,8 +89,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
-else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
-is_same_v<IndexDataType, I32>)
+#endif
+#ifdef __fp32__
+if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
+is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
+#endif
return op_ptrs;
}
};
...
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef DL_KERNELS
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
@@ -76,7 +76,7 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
+#endif
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
@@ -181,7 +181,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
+#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
+#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
}
}
@@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
+#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
+#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
}
}
@@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
+#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
+#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
}
}
@@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
+#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
+#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
}
}
@@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
\ No newline at end of file
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
std::vector<
@@ -64,7 +64,7 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
+#endif
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
@@ -163,12 +163,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, PassThrough>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
}
}
@@ -229,7 +233,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, TanH>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
}
}
@@ -243,3 +249,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
std::vector<
@@ -63,7 +63,7 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
+#endif
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
@@ -161,12 +161,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, PassThrough>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
}
}
@@ -227,7 +231,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, TanH>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
}
}
@@ -241,3 +247,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
+#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
PassThrough,
Activation_Mul2_Clamp<Relu>>>>&
instances);
+#endif
void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
@@ -128,12 +128,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
if constexpr(is_same_v<Activation, PassThrough>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_perchannel_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_perchannel_quantization_int8_instances(op_ptrs);
}
else if constexpr(is_same_v<Activation, Relu>)
{
+#ifdef DL_KERNELS
add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(op_ptrs);
+#endif
add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(op_ptrs);
}
}
@@ -147,3 +151,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
+#endif