Commit b93575ca authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents 54df59bf c8a8385f
...@@ -102,9 +102,10 @@ struct FillMonotonicSeq ...@@ -102,9 +102,10 @@ struct FillMonotonicSeq
} }
template <typename ForwardRange> template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<decltype( auto operator()(ForwardRange&& range) const
std::declval<const FillMonotonicSeq&>()(std::begin(std::forward<ForwardRange>(range)), -> std::void_t<decltype(std::declval<const FillMonotonicSeq&>()(
std::end(std::forward<ForwardRange>(range))))> std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{ {
(*this)(std::begin(std::forward<ForwardRange>(range)), (*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))); std::end(std::forward<ForwardRange>(range)));
......
add_instance_library(device_batched_gemm_instance set(BATCHED_GEMM_INSTANCES)
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp endif()
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp endif()
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
) device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
endif()
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
) )
endif()
\ No newline at end of file
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_bias_permute_instance add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_gemm_instance add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_reduce_instance add_instance_library(device_batched_gemm_reduce_instance
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
) )
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_softmax_gemm_instance add_instance_library(device_batched_gemm_softmax_gemm_instance
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
) )
endif()
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES)
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp endif()
) if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
endif()
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES})
add_instance_library(device_contraction_bilinear_instance set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
#float #float
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
endif()
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#double #double
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
) endif()
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
add_instance_library(device_contraction_scale_instance set(DEVICE_CONTRACTION_SCALE_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
#float #float
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
endif()
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#double #double
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
) endif()
add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES})
set(CONV2D_BWD_DATA_INSTANCES) set(CONV2D_BWD_DATA_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp) if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
endif()
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
endif() endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp) if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()
endif() endif()
add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES}) add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef DL_KERNELS
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( ...@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef DL_KERNELS
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( ...@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef DL_KERNELS
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( ...@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef CK_ENABLE_BF16
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -155,3 +155,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( ...@@ -155,3 +155,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__ #ifdef CK_ENABLE_INT8
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
......
add_instance_library(device_conv2d_fwd_instance set(DEVICE_CONV2D_FWD_INSTANCES)
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp endif()
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()
add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES})
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef CK_ENABLE_BF16
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -126,3 +126,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( ...@@ -126,3 +126,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef CK_ENABLE_FP16
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -118,3 +118,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( ...@@ -118,3 +118,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef CK_ENABLE_FP32
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -117,3 +117,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( ...@@ -117,3 +117,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment