Commit 8169b045 authored by Chao Liu
Browse files

added conv_c_shuffle+bias_relu

parent adbda385
...@@ -19,7 +19,10 @@ using DeviceConvFwdBiasReluPtr = ...@@ -19,7 +19,10 @@ using DeviceConvFwdBiasReluPtr =
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddRelu>; ck::tensor_operation::element_wise::AddRelu>;
void add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances( void add_device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluPtr>&);
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasReluPtr>&); std::vector<DeviceConvFwdBiasReluPtr>&);
} // namespace device_conv2d_fwd_bias_activation_instance } // namespace device_conv2d_fwd_bias_activation_instance
...@@ -295,7 +298,10 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -295,7 +298,10 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance:: ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance::
add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances(op_ptrs); add_device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance::
add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
} }
if(op_ptrs.size() <= 0) if(op_ptrs.size() <= 0)
......
...@@ -18,17 +18,17 @@ using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<ck::tensor_operation::element_wise ...@@ -18,17 +18,17 @@ using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<ck::tensor_operation::element_wise
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>; ck::tensor_operation::element_wise::PassThrough>;
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp32_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_fp16_instances( void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances(
std::vector<DeviceConvFwdNoOpPtr>&); std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_fp16_instances( void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances(
std::vector<DeviceConvFwdNoOpPtr>&); std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_fp16_instances( void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdNoOpPtr>&); std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv2d_fwd_instance } // namespace device_conv2d_fwd_instance
...@@ -143,23 +143,27 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -143,23 +143,27 @@ void profile_conv_fwd_impl(int do_verification,
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>) ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp32_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
#if 0 // debug
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_fp16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_fp16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances(conv_ptrs);
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_fp16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances(conv_ptrs);
#endif
#if 1 // debug
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_fp16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
#endif
} }
if(conv_ptrs.size() <= 0) if(conv_ptrs.size() <= 0)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment