Commit 5db79de0 authored by carlushuang's avatar carlushuang
Browse files

add a direct bias-relu-add implementation

parent 5024f317
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#define TEST_FUSION_BIAS_RELU 1 #define TEST_FUSION_BIAS_RELU 1
#define TEST_FUSION_BIAS 2 #define TEST_FUSION_BIAS 2
#define TEST_FUSION_BIAS_ADD_RELU 3 #define TEST_FUSION_BIAS_ADD_RELU 3
#define TEST_FUSION TEST_FUSION_BIAS_ADD_RELU #define TEST_FUSION TEST_FUSION_BIAS_RELU_ADD
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
...@@ -171,6 +171,11 @@ void add_device_conv2d_fwd_bias_add_relu_avx2_nhwc_yxck_nhwk_mt( ...@@ -171,6 +171,11 @@ void add_device_conv2d_fwd_bias_add_relu_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddAddRelu>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddAddRelu>>&
instances); instances);
// ------------------ direct-conv nhwc-kyxck8-nhwk
void add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
...@@ -623,6 +628,8 @@ int main(int argc, char* argv[]) ...@@ -623,6 +628,8 @@ int main(int argc, char* argv[])
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c(
conv_ptrs); conv_ptrs);
} }
ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
#endif #endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK #if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if(omp_get_max_threads() > 1) if(omp_get_max_threads() > 1)
......
...@@ -1768,6 +1768,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_ ...@@ -1768,6 +1768,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
static constexpr bool FuseBias = true;
static constexpr bool FuseAdd = true;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN( constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index&, const Index&,
...@@ -2434,6 +2437,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN ...@@ -2434,6 +2437,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
static constexpr bool FuseBias = true;
static constexpr bool FuseAdd = false;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN( constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index&, const Index&,
......
#include <stdlib.h>
#include <utility>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp"
#include "ck/tensor_operation/cpu/device/device_convnd_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
#include "ck/tensor_operation/cpu/element/element_wise_operation_cpu.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {

// All instances registered in this TU are single-precision end to end.
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;

// Tensor layouts expressed through their GEMM view: the NHWC input maps to a
// row-major matrix, the KYXCK8 weight to a column-major one (K8 = packed K).
using InLayout = ck::tensor_layout::gemm::RowMajor;     // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXCK8

// Use ordinary (temporal) stores when writing the output tensor.
static constexpr bool NonTemporalStore = false;

// Short aliases for the element-wise operator functors used by the instances.
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add = ck::tensor_operation::cpu::element_wise::Add;
using AddAddRelu = ck::tensor_operation::cpu::element_wise::AddAddRelu;

// Forward-convolution specializations: generic filter, 1x1 pad-0, and
// 1x1 stride-1 pad-0 fast paths.
static constexpr auto ConvFwdDefault =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;

static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

// GEMM K-loop specializations: generic loop vs. looping over the C dimension.
static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;

static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;

// Loop-nest orderings selectable per instance (M/N/K traversal order).
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
/// Appends the direct (non-GEMM-lowered) AVX2 conv2d fwd + bias + ReLU + add
/// instances for the NHWC / KYXCK8 / NHWK layout to @p instances.
/// Two variants are registered, differing only in loop-nest order (MKN vs MNK).
void add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    // Name the fully-parameterized device op once instead of repeating the
    // template argument list for every instance.
    using DirectConvInstance =
        DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<
            float, float, float, float, float, PT, PT, AddReluAdd, ConvFwdDefault,
            2, 6, 16, false, false, false, true, true, false>;

    ck::tensor_operation::device::instance::add_device_operation_instances(
        instances,
        std::make_tuple(DirectConvInstance({0, 0, 0, DefaultGemmKLoop, LoopOver_MKN}),
                        DirectConvInstance({0, 0, 0, DefaultGemmKLoop, LoopOver_MNK})));
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment