Commit f8b551da authored by carlushuang
Browse files

add bias_relu, bias fusion

parent bfa4c686
......@@ -37,6 +37,8 @@ template <typename InDataType,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
......@@ -607,19 +609,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
!UseBLocalBuffer,
ConvForwardSpecialization>;
// Type used to store the C (output) tile. Here it is fixed to the
// ..._MatC_Store_Bias_Residual_MxN specialization; none of its template
// arguments depend on the FuseBias / FuseAdd flags.
// NOTE(review): this page interleaves removed and added diff lines. The hunk
// header (-607,19 +609,51) suggests this alias is the removed side, superseded by
// `using CThreadwiseCopy = decltype(GetCThreadwiseCopy());` further down — confirm
// against the real header before relying on it.
using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>;
static constexpr auto GetCThreadwiseCopy()
{
constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
if constexpr(FuseBias && FuseAdd)
{
return ck::cpu::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType,
......
......@@ -37,6 +37,8 @@ template <typename InDataType,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
......@@ -584,19 +586,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
!UseBLocalBuffer,
ConvForwardSpecialization>;
// Type used to store the C (output) tile. Here it is fixed to the
// ..._MatC_Store_Bias_Residual_MxN specialization; none of its template
// arguments depend on the FuseBias / FuseAdd flags.
// NOTE(review): this page interleaves removed and added diff lines. The hunk
// header (-584,19 +586,51) suggests this alias is the removed side, superseded by
// `using CThreadwiseCopy = decltype(GetCThreadwiseCopy());` further down — confirm
// against the real header before relying on it.
using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>;
static constexpr auto GetCThreadwiseCopy()
{
constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
if constexpr(FuseBias && FuseAdd)
{
return ck::cpu::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType,
......@@ -954,7 +988,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
......
......@@ -36,6 +36,8 @@ template <typename InDataType,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
......@@ -580,19 +582,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
!UseBLocalBuffer,
ConvForwardSpecialization>;
// Type used to store the C (output) tile. Here it is fixed to the
// ..._MatC_Store_Bias_Residual_MxN specialization; none of its template
// arguments depend on the FuseBias / FuseAdd flags.
// NOTE(review): this page interleaves removed and added diff lines. The hunk
// header (-580,19 +582,51) suggests this alias is the removed side, superseded by
// `using CThreadwiseCopy = decltype(GetCThreadwiseCopy());` further down — confirm
// against the real header before relying on it.
using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>;
static constexpr auto GetCThreadwiseCopy()
{
constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
if constexpr(FuseBias && FuseAdd)
{
return ck::cpu::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment