Commit f8b551da authored by carlushuang's avatar carlushuang
Browse files

add bias_relu, bias fusion

parent bfa4c686
...@@ -37,6 +37,8 @@ template <typename InDataType, ...@@ -37,6 +37,8 @@ template <typename InDataType,
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
bool UseCLocalBuffer, bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM> bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
...@@ -607,19 +609,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -607,19 +609,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization>; ConvForwardSpecialization>;
using CThreadwiseCopy = static constexpr auto GetCThreadwiseCopy()
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< {
CDataType, constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
C0DataType, if constexpr(FuseBias && FuseAdd)
C1DataType, {
CDataType, return ck::cpu::
CGridDesc, ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
C0GridDesc, CDataType,
C1GridDesc, C0DataType,
decltype(GetOutputBlockDescriptor()), C1DataType,
OutElementwiseOperation, CDataType,
!UseCLocalBuffer, CGridDesc,
BiasAlongGemmM>; C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN< using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType, ADataType, // InDataType,
......
...@@ -37,6 +37,8 @@ template <typename InDataType, ...@@ -37,6 +37,8 @@ template <typename InDataType,
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
bool UseCLocalBuffer, bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM> bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
...@@ -584,19 +586,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -584,19 +586,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization>; ConvForwardSpecialization>;
using CThreadwiseCopy = static constexpr auto GetCThreadwiseCopy()
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< {
CDataType, constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
C0DataType, if constexpr(FuseBias && FuseAdd)
C1DataType, {
CDataType, return ck::cpu::
CGridDesc, ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
C0GridDesc, CDataType,
C1GridDesc, C0DataType,
decltype(GetOutputBlockDescriptor()), C1DataType,
OutElementwiseOperation, CDataType,
!UseCLocalBuffer, CGridDesc,
BiasAlongGemmM>; C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN< using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType, ADataType, // InDataType,
...@@ -954,7 +988,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -954,7 +988,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
return "G"; return "G";
}; };
// clang-format off // clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial) str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_KYXCK8" << "DFwd_BAA_Avx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec) <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
......
...@@ -36,6 +36,8 @@ template <typename InDataType, ...@@ -36,6 +36,8 @@ template <typename InDataType,
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
bool UseCLocalBuffer, bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM> bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
...@@ -580,19 +582,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -580,19 +582,51 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization>; ConvForwardSpecialization>;
using CThreadwiseCopy = static constexpr auto GetCThreadwiseCopy()
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< {
CDataType, constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
C0DataType, if constexpr(FuseBias && FuseAdd)
C1DataType, {
CDataType, return ck::cpu::
CGridDesc, ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
C0GridDesc, CDataType,
C1GridDesc, C0DataType,
decltype(GetOutputBlockDescriptor()), C1DataType,
OutElementwiseOperation, CDataType,
!UseCLocalBuffer, CGridDesc,
BiasAlongGemmM>; C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN< using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType, ADataType, // InDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment