Commit ddad386b authored by carlushuang
Browse files

add cpu bias_relu_add example

parent 505194d7
...@@ -58,3 +58,4 @@ add_subdirectory(16_gemm_reduce) ...@@ -58,3 +58,4 @@ add_subdirectory(16_gemm_reduce)
add_subdirectory(18_batched_gemm_reduce) add_subdirectory(18_batched_gemm_reduce)
add_subdirectory(cpu_01_conv2d_fwd) add_subdirectory(cpu_01_conv2d_fwd)
add_subdirectory(cpu_02_conv2d_fwd_bias_relu_add)
# Example binary exercising the fused CPU conv2d + bias + ReLU + add path.
add_example_executable(example_cpu_conv2d_fwd_bias_relu_add cpu_conv2d_fwd_bias_relu_add.cpp)
# Link against the prebuilt instance library that provides the device operators.
target_link_libraries(example_cpu_conv2d_fwd_bias_relu_add PRIVATE device_conv2d_fwd_bias_activation_add_cpu_instance)
# OpenMP link flag, library, and compile flag come from variables set by the parent scope.
set_target_properties(example_cpu_conv2d_fwd_bias_relu_add PROPERTIES LINK_FLAGS "${OMP_LINK_FLAG}")
target_link_libraries(example_cpu_conv2d_fwd_bias_relu_add PRIVATE "${OMP_LIBRARY}")
target_compile_options(example_cpu_conv2d_fwd_bias_relu_add PRIVATE "${OMP_CXX_FLAG}")
#ifndef DEVICE_CONV_FWD_CPU_HPP
#define DEVICE_CONV_FWD_CPU_HPP

#include <iostream>
#include <memory> // std::unique_ptr (was only available transitively before)
#include <vector> // std::vector (was only available transitively before)

#include "device_base_cpu.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {

// Abstract interface for a CPU forward-convolution operator. Concrete
// instances supply a type-erased argument object and an invoker.
template <typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
struct DeviceConvFwd : public BaseOperator
{
    // Builds a type-erased argument for one convolution problem.
    // The spatial-length / stride / dilation / pad vectors carry one entry per
    // spatial dimension (presumably 2 for conv2d — confirm against callers).
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in,
                        const void* p_wei,
                        void* p_out,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
using DeviceConvFwdPtr = std::unique_ptr<
    DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;

// Variant whose epilogue additionally consumes a bias tensor (p_bias_grid)
// and a residual tensor (p_add_grid); otherwise identical to DeviceConvFwd.
template <typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
struct DeviceConvFwdBiasActivationAdd : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in,
                        const void* p_wei,
                        void* p_out,
                        const void* p_bias_grid,
                        const void* p_add_grid,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
using DeviceConvFwdBiasActivationAddPtr =
    std::unique_ptr<DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
                                                   WeiElementwiseOperation,
                                                   OutElementwiseOperation>>;

} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -148,6 +148,27 @@ struct AddReluAdd ...@@ -148,6 +148,27 @@ struct AddReluAdd
y = _mm256_add_ps(b, x2); y = _mm256_add_ps(b, x2);
} }
// Scalar reference path: out = max(x0 + x1, 0) + x2 — presumably
// conv-output + bias, ReLU, then residual add (confirm against callers).
float Apply(const float& x0, const float& x1, const float& x2) const
{
    const float pre_act   = x0 + x1;
    const float activated = (pre_act > 0.0f) ? pre_act : 0.0f;
    return activated + x2;
}
// 4-lane SSE path: per-lane max(x0 + x1, 0) + x2.
float4_t Apply(const float4_t& x0, const float4_t& x1, const float4_t& x2) const
{
    const float4_t pre_act = _mm_add_ps(x0, x1);
    return _mm_add_ps(_mm_max_ps(pre_act, _mm_setzero_ps()), x2);
}
// 8-lane AVX path: per-lane max(x0 + x1, 0) + x2.
float8_t Apply(const float8_t& x0, const float8_t& x1, const float8_t& x2) const
{
    const float8_t pre_act = _mm256_add_ps(x0, x1);
    return _mm256_add_ps(_mm256_max_ps(pre_act, _mm256_setzero_ps()), x2);
}
// Operator identifier used for instance selection / logging.
// Fixed: returning a string literal as non-const `char*` is ill-formed in
// ISO C++11 and later; the return type must be `const char*`.
static constexpr const char* Name() { return "AddReluAdd"; }
};
......
...@@ -22,3 +22,4 @@ function(add_instance_library INSTANCE_NAME) ...@@ -22,3 +22,4 @@ function(add_instance_library INSTANCE_NAME)
endfunction(add_instance_library INSTANCE_NAME) endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(conv2d_fwd) add_subdirectory(conv2d_fwd)
add_subdirectory(conv2d_fwd_bias_activation_add)
# device_conv2d_fwd_bias_activation_add_cpu_instance
# Shared library bundling the AVX2 NHWC/KYXC/NHWK conv2d+bias+activation+add instances.
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_bias_activation_add_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
# NOTE(review): target_compile_features is called with an empty feature list,
# which is a no-op — if a C++ standard level is required, name it explicitly
# (e.g. cxx_std_14); confirm intent before changing build behavior.
target_compile_features(device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC)
# PIC is required since the objects go into a shared library.
set_target_properties(device_conv2d_fwd_bias_activation_add_cpu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
# OpenMP library and compile flag are provided by the parent scope.
target_link_libraries(device_conv2d_fwd_bias_activation_add_cpu_instance PRIVATE "${OMP_LIBRARY}")
target_compile_options(device_conv2d_fwd_bias_activation_add_cpu_instance PRIVATE "${OMP_CXX_FLAG}")
install(TARGETS device_conv2d_fwd_bias_activation_add_cpu_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_bias_activation_add_cpu_instance)
#include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
// Element/accumulator types for the f32 instances below.
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
// GEMM-view layouts of the conv tensors (comments give the tensor layout).
using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false;
// Shorthand for the element-wise operators used in the instance lists.
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
// Convolution specializations: generic, 1x1 with zero pad, 1x1 stride-1 zero pad.
// NOTE(review): ConvFwd1x1P0 is not referenced by the visible instance lists.
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
// GEMM-K loop specializations used when expanding the instance macro.
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
// Loop-nest orders for the blocked GEMM.
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
// Expands to 8 comma-separated DeviceConvNDFwdBiasActivationAddAvx2 template
// instantiations sharing one blocking configuration: the cross product of
// {ConvFwdDefault, ConvFwd1x1S1P0} x {GemmKLoopOverC, DefaultGemmKLoop} x
// {LoopOver_MNK, LoopOver_MKN}. The literal `2` is presumably the number of
// spatial dimensions (conv2d) — confirm against the device-op template.
// No comments may be placed inside the macro body: every interior line ends
// with a continuation backslash.
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>
// clang-format on
// Default instance list: results are written directly to the output tensor
// (no thread-local C buffer). Numeric args are m/n/k block sizes and the
// m/n per-thread tile, followed by the a/b/c local-buffer and bias_along_m flags.
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, false, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
// Same blocking configs as the default list, but with a local C buffer
// (c_local_buf = true) so partial tiles can be handled out of place.
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
// Adds small-block configurations (48..120 m_per_block) so work can be split
// across many threads; all entries use a local C buffer (c_local_buf = true).
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, true, true, true, false),
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>;
// clang-format on
// Appends the default (no local C buffer) f32 AVX2 instances to `instances`.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using instance_tuple_t =
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_tuple_t{});
}
// Appends the local-C-buffer instance set (single-thread, gemm_n not a
// multiple of 8) to `instances`.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using instance_tuple_t =
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_tuple_t{});
}
// Appends the multi-thread instance set (small blocks + local C buffer) to
// `instances`.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using instance_tuple_t =
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_tuple_t{});
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment