"vscode:/vscode.git/clone" did not exist on "3941bd1f1507d52f623b82c0b77e0eb640d9b8c3"
Commit 5683ea4e authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents f0831350 dddc2115
@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#ifdef __int8__
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
+#ifdef DL_KERNELS
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_perlayer_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
         PassThrough,
         Activation_Mul_Clamp<Relu>>>>&
     instances);
+#endif
 void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
         NHWGC,
@@ -125,12 +125,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     {
         if constexpr(is_same_v<Activation, PassThrough>)
         {
+#ifdef DL_KERNELS
             add_device_conv2d_dl_perlayer_quantization_int8_instances(op_ptrs);
+#endif
             add_device_conv2d_xdl_perlayer_quantization_int8_instances(op_ptrs);
         }
         else if constexpr(is_same_v<Activation, Relu>)
        {
+#ifdef DL_KERNELS
             add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(op_ptrs);
+#endif
             add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(op_ptrs);
         }
     }
@@ -144,3 +148,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif

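In the hunks above, the DL-kernel declarations and the factory calls that use them sit behind the same preprocessor guards, so a build that defines neither __int8__ nor DL_KERNELS never references the excluded instance functions. Below is a minimal, self-contained C++ sketch of that pattern; OpInstance, add_dl_int8_instances and the other names are simplified stand-ins for illustration, not the real CK API.

// Sketch only: stand-in types and hypothetical function names.
#include <memory>
#include <vector>

struct OpInstance { virtual ~OpInstance() = default; };  // stand-in for the device-op interface
using OpPtrs = std::vector<std::unique_ptr<OpInstance>>;

#ifdef __int8__
#ifdef DL_KERNELS
// Declared and defined only when DL kernels are part of the build.
void add_dl_int8_instances(OpPtrs& op_ptrs) { op_ptrs.push_back(std::make_unique<OpInstance>()); }
#endif
void add_xdl_int8_instances(OpPtrs& op_ptrs) { op_ptrs.push_back(std::make_unique<OpInstance>()); }

OpPtrs get_int8_instances()
{
    OpPtrs op_ptrs;
#ifdef DL_KERNELS
    add_dl_int8_instances(op_ptrs);  // call site guarded by the same macro as the declaration
#endif
    add_xdl_int8_instances(op_ptrs); // XDL instances are registered in every int8 build
    return op_ptrs;
}
#endif // __int8__

int main() { return 0; }  // nothing to run; the point is the compile-time guards
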
@@ -40,7 +40,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef __fp16__
         if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
                      std::is_same_v<OutDataType, F16>)
         {
@@ -65,8 +65,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
                 add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs);
             }
         }
-        else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
-                          std::is_same_v<OutDataType, F32>)
+#endif
+#ifdef __fp32__
+        if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
+                     std::is_same_v<OutDataType, F32>)
         {
             if constexpr(Rank == 3)
             {
@@ -89,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
                 add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs);
             }
         }
+#endif
         return op_ptrs;
     }
 };

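Note that the F32 branch of the softmax factory changes from "else if constexpr" to a plain "if constexpr": each data-type branch is now independent, so the preprocessor can drop one branch without leaving a dangling else. A compilable sketch of that dispatch follows, using stand-in tag types and an int vector in place of the real instance list; the names here are illustrative, not the CK API.

#include <type_traits>
#include <vector>

struct F16 {};  // stand-in tag types
struct F32 {};

template <typename InDataType, typename OutDataType>
std::vector<int> get_instances()  // hypothetical; the real factory returns device-op pointers
{
    std::vector<int> op_ptrs;
#ifdef __fp16__
    if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<OutDataType, F16>)
    {
        op_ptrs.push_back(16);  // the real code calls add_device_softmax_f16_* here
    }
#endif
#ifdef __fp32__
    if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<OutDataType, F32>)
    {
        op_ptrs.push_back(32);  // the real code calls add_device_softmax_f32_* here
    }
#endif
    return op_ptrs;  // empty when the requested type was compiled out
}

int main() { return static_cast<int>(get_instances<F32, F32>().size()); }
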
-add_instance_library(device_batched_gemm_instance
-    device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
-    device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
-    device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
-    device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
-    device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
-    device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
-    device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
-    device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
-    device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
-)
+set(BATCHED_GEMM_INSTANCES)
+if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+    list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
+        device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
+endif()
+if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+    list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
+        device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
+endif()
+if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+    list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
+        device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
+endif()
+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+    list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
+        device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
+        device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
+endif()
+add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})

if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
) )
endif()
\ No newline at end of file
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_bias_permute_instance add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_gemm_instance add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_reduce_instance add_instance_library(device_batched_gemm_reduce_instance
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_softmax_gemm_instance add_instance_library(device_batched_gemm_softmax_gemm_instance
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
) )
endif()
-add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
-    device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
-    device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
-    device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
-    device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
-)
+set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES)
+if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+    list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
+    list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
+endif()
+if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+    list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
+    list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
+endif()
+add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES})

-add_instance_library(device_contraction_bilinear_instance
+set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
+if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
 #float
-    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
+    list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
-    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp
+    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
+endif()
+if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
 #double
-    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
+    list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
     device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
-    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp
-)
+    device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
+endif()
+add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})

-add_instance_library(device_contraction_scale_instance
+set(DEVICE_CONTRACTION_SCALE_INSTANCES)
+if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
 #float
-    device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
+    list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
     device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
     device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
-    device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp
+    device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
+endif()
+if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
 #double
-    device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
+    list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
     device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
     device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
-    device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp
-)
+    device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
+endif()
+add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES})

 set(CONV2D_BWD_DATA_INSTANCES)
 if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
     list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
-    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
+    if(DL_KERNELS)
+        list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
+    endif()
 endif()
 if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
     list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
 endif()
 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
     list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
+    if(DL_KERNELS)
+        list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
+    endif()
 endif()
 if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
     list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
-    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
+    if(DL_KERNELS)
+        list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
+    endif()
 endif()
 add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})

@@ -9,7 +9,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#ifdef DL_KERNELS
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif

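The DL kernels are now gated at two levels: the CMake block above only lists the dl_*.cpp sources when DL_KERNELS is set, and each of those sources (as in the hunk just shown) is additionally wrapped in #ifdef DL_KERNELS, so a stray compilation without the flag yields an empty translation unit instead of unwanted instances. A hedged sketch of such a guarded instance source file, with hypothetical type and function names, compiled as part of an instance library rather than as a standalone program:

// Sketch of a DL-guarded instance translation unit (hypothetical names).
// Without -DDL_KERNELS this file compiles to an empty object.
#include <memory>
#include <vector>

#ifdef DL_KERNELS
namespace ck {
namespace tensor_operation {
namespace device {

struct DlInstance {};  // stand-in for the real DL device operation

void add_device_conv2d_bwd_data_dl_example_instances(  // hypothetical name
    std::vector<std::unique_ptr<DlInstance>>& instances)
{
    instances.push_back(std::make_unique<DlInstance>());
}

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif // DL_KERNELS
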
@@ -9,7 +9,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#ifdef DL_KERNELS
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif

@@ -9,7 +9,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#ifdef DL_KERNELS
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif

@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#ifdef __bf16__
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -155,3 +155,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif

-add_instance_library(device_conv2d_fwd_instance
-    device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
-    device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
-    device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
-    device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
-    device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
-)
+set(DEVICE_CONV2D_FWD_INSTANCES)
+if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+    list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
+endif()
+if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+    list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
+    list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp)
+endif()
+if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+    list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
+endif()
+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+    list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
+endif()
+add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES})

@@ -8,7 +8,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#ifdef __bf16__
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -126,3 +126,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif

@@ -8,7 +8,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#ifdef __fp16__
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -118,3 +118,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif

@@ -8,7 +8,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+#ifdef __fp32__
 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -117,3 +117,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif