Commit b892a14a authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed dtype

parent adfcb029
...@@ -2,11 +2,14 @@ if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) ...@@ -2,11 +2,14 @@ if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations)
endif()
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_operations)
endif() endif()
if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations)
endif() endif()
...@@ -16,6 +16,7 @@ namespace ck { ...@@ -16,6 +16,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv1d forward, GNWC/GKXC/GNWK // grouped conv1d forward, GNWC/GKXC/GNWK
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
...@@ -32,6 +33,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( ...@@ -32,6 +33,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
...@@ -47,6 +49,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( ...@@ -47,6 +49,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
...@@ -62,6 +65,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( ...@@ -62,6 +65,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8 #ifdef CK_ENABLE_INT8
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
...@@ -77,6 +81,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( ...@@ -77,6 +81,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv2d forward, GNHWC/GKYXC/GNHWK // grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
...@@ -93,6 +98,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( ...@@ -93,6 +98,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -108,6 +114,7 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -108,6 +114,7 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -123,38 +130,7 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( ...@@ -123,38 +130,7 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -169,23 +145,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -169,23 +145,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#ifdef DL_KERNELS
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif #endif
#ifdef CK_ENABLE_INT8 #ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -201,22 +162,7 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( ...@@ -201,22 +162,7 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
...@@ -233,6 +179,7 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( ...@@ -233,6 +179,7 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -248,6 +195,7 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( ...@@ -248,6 +195,7 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -263,6 +211,7 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( ...@@ -263,6 +211,7 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
...@@ -279,6 +228,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( ...@@ -279,6 +228,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
...@@ -294,6 +244,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( ...@@ -294,6 +244,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
...@@ -309,6 +260,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( ...@@ -309,6 +260,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8 #ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
...@@ -324,6 +276,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( ...@@ -324,6 +276,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
...@@ -340,6 +293,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( ...@@ -340,6 +293,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
...@@ -388,6 +342,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( ...@@ -388,6 +342,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8 #ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
...@@ -404,6 +359,71 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( ...@@ -404,6 +359,71 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -477,7 +497,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -477,7 +497,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
#endif #endif
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>) is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{ {
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
...@@ -485,22 +506,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -485,22 +506,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
#ifdef DL_KERNELS }
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
#endif #endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
#endif
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
} }
#endif #endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>) is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
...@@ -508,37 +542,52 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -508,37 +542,52 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_INT8 #ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
} }
#endif #endif
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>) is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{ {
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
#ifdef DL_KERNELS }
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
#endif #endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
#ifdef DL_KERNELS }
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
#endif #endif
#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS))
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_BDF16 #ifdef CK_ENABLE_BDF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>) is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
...@@ -547,7 +596,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -547,7 +596,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
#endif #endif
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>) is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{ {
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
...@@ -579,7 +629,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -579,7 +629,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
#endif #endif
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>) is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{ {
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
...@@ -588,6 +639,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -588,6 +639,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
} }
#endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>) is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment