Commit c7c47fd7 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Merge branch 'develop' into bwroblew/dpp8

parents f8eb91d7 578142db
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __bf16__
// conv1d backward data // conv1d backward data
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1, std::vector<std::unique_ptr<DeviceConvBwdData<1,
...@@ -29,16 +29,19 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( ...@@ -29,16 +29,19 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __fp16__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#ifdef __fp32__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#ifdef __int8__ #ifdef __int8__
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1, std::vector<std::unique_ptr<DeviceConvBwdData<1,
...@@ -52,6 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( ...@@ -52,6 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef __bf16__
// conv2d backward data // conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -64,7 +68,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( ...@@ -64,7 +68,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __fp16__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC, NHWC,
...@@ -76,7 +81,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( ...@@ -76,7 +81,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __fp32__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC, NHWC,
...@@ -88,6 +94,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( ...@@ -88,6 +94,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __int8__ #ifdef __int8__
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -101,6 +108,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -101,6 +108,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef DL_KERNELS
#ifdef __fp16__
// conv2d dl // conv2d dl
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -113,7 +122,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( ...@@ -113,7 +122,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __fp32__
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC, NHWC,
...@@ -125,6 +135,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( ...@@ -125,6 +135,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __int8__ #ifdef __int8__
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2, std::vector<std::unique_ptr<DeviceConvBwdData<2,
...@@ -138,6 +149,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( ...@@ -138,6 +149,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#endif
#ifdef __bf16__
// conv3d backward data // conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
...@@ -150,7 +163,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( ...@@ -150,7 +163,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __fp16__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC, NDHWC,
...@@ -162,7 +176,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( ...@@ -162,7 +176,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __fp32__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC, NDHWC,
...@@ -174,6 +189,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( ...@@ -174,6 +189,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
#ifdef __int8__ #ifdef __int8__
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3, std::vector<std::unique_ptr<DeviceConvBwdData<3,
...@@ -229,20 +245,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -229,20 +245,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{ {
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && #ifdef __fp16__
is_same_v<OutDataType, half_t>) if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{ {
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && #endif
is_same_v<WeiDataType, ck::bhalf_t> && #ifdef __bf16__
is_same_v<OutDataType, ck::bhalf_t>) if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
} }
#endif
#ifdef __int8__ #ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
} }
...@@ -255,26 +274,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -255,26 +274,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
#endif
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && #ifdef __fp16__
is_same_v<OutDataType, half_t>) if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
#endif
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && #endif
is_same_v<WeiDataType, ck::bhalf_t> && #ifdef __bf16__
is_same_v<OutDataType, ck::bhalf_t>) if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
} }
#endif
#ifdef __int8__ #ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
#endif
} }
#endif #endif
} }
...@@ -286,20 +314,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw ...@@ -286,20 +314,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{ {
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && #ifdef __fp16__
is_same_v<OutDataType, half_t>) if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{ {
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && #endif
is_same_v<WeiDataType, ck::bhalf_t> && #ifdef __bf16__
is_same_v<OutDataType, ck::bhalf_t>) if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
} }
#endif
#ifdef __int8__ #ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
} }
......
...@@ -18,11 +18,17 @@ namespace device { ...@@ -18,11 +18,17 @@ namespace device {
namespace instance { namespace instance {
// conv2d forward // conv2d forward
#ifdef __fp16__
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef __bf16__
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2, std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC, NHWC,
...@@ -34,17 +40,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( ...@@ -34,17 +40,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( #ifdef __fp32__
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>& DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#ifdef __int8__
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2, std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC, NHWC,
...@@ -56,6 +59,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -56,6 +59,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
...@@ -99,23 +103,29 @@ struct DeviceOperationInstanceFactory< ...@@ -99,23 +103,29 @@ struct DeviceOperationInstanceFactory<
{ {
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
} }
#ifdef __fp16__
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
#endif
#ifdef __bf16__
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
} }
#endif
#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
} }
#endif
} }
return op_ptrs; return op_ptrs;
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen ...@@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -363,6 +363,7 @@ struct DeviceOperationInstanceFactory< ...@@ -363,6 +363,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
} }
} }
#ifdef __fp16__
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>) is_same_v<CDataType, half_t>)
{ {
...@@ -412,6 +413,8 @@ struct DeviceOperationInstanceFactory< ...@@ -412,6 +413,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
} }
} }
#endif
#ifdef __bf16__
else if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> && else if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<CDataType, ck::bhalf_t>) is_same_v<CDataType, ck::bhalf_t>)
{ {
...@@ -436,6 +439,7 @@ struct DeviceOperationInstanceFactory< ...@@ -436,6 +439,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
} }
} }
#endif
#ifdef __int8__ #ifdef __int8__
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> && else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>) is_same_v<CDataType, int8_t>)
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -144,3 +144,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -144,3 +144,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -136,3 +136,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -136,3 +136,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -16,7 +16,7 @@ namespace ck { ...@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef __fp16__
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances( void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt ...@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -192,3 +192,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -192,3 +192,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment