Commit d1006d46 authored by Jing Zhang

removed AddBias

parent c4411860
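In effect, the dedicated AddBias functor is deleted and every bias epilogue is routed through the pre-existing generic Add functor, which computes the same e = c + d0. The diff below adds the one (float, float) -> half_t specialization that Add was missing and then swaps the aliases everywhere. A minimal, self-contained sketch of the functor pattern involved (illustration only, not CK source; _Float16 stands in for ck::half_t and requires a compiler with _Float16 support, e.g. recent Clang):

// Sketch of the binary element-wise functor pattern: the destination type
// drives the final conversion, which is what lets the generic Add subsume
// the removed AddBias. Not CK code; types and values are hypothetical.
#include <cstdio>

using half_t = _Float16; // stand-in for ck::half_t

struct Add
{
    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        // add in the input precision, then convert once to the destination
        // type, mirroring e = c + d0 in the removed AddBias specializations
        y = static_cast<Y>(x0 + x1);
    }
};

int main()
{
    Add    cde_op;               // plays the role of CDEElementOp
    half_t e{};
    float  c = 1.5f, d0 = 0.25f; // hypothetical GEMM accumulator and bias
    cde_op(e, c, d0);            // e = half(c + d0)
    std::printf("%f\n", static_cast<float>(e));
}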
@@ -30,7 +30,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using AddBias = ck::tensor_operation::element_wise::AddBias;
+using Add = ck::tensor_operation::element_wise::Add;
 
 using ADataType = F16;
 using BDataType = F16;
@@ -49,7 +49,7 @@ using ELayout = Row;
 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
-using CDEElementOp = AddBias;
+using CDEElementOp = Add;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding;
@@ -36,6 +36,13 @@ struct Add
         y = x0 + type_convert<half_t>(x1);
     };
 
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const float& x0, const float& x1) const
+    {
+        y = type_convert<half_t>(x0 + x1);
+    };
+
     template <>
     __host__ __device__ constexpr void
     operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
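The specialization added here keeps both inputs in fp32 and converts the sum to half once, whereas the neighboring specializations round an operand to half before adding. A standalone sketch of the difference (illustration only; _Float16 stands in for ck::half_t, and the inputs are deliberately chosen near a half-precision rounding boundary):

// Contrast single rounding (add in fp32, convert once) with double rounding
// (round an input to half first). Not CK code; values are hypothetical.
#include <cstdio>

using half_t = _Float16; // stand-in for ck::half_t

int main()
{
    float x0 = 1.000390625f;  // ~0.4 half-precision ulp above 1.0
    float x1 = 0.0001953125f; // ~0.2 half-precision ulp

    // single rounding, as in the specialization added above
    half_t sum_first = static_cast<half_t>(x0 + x1);

    // double rounding: round operands to half first, then add and convert
    half_t convert_first =
        static_cast<half_t>(static_cast<float>(static_cast<half_t>(x0)) +
                            static_cast<float>(static_cast<half_t>(x1)));

    // expected: 1.000977 vs 1.000000 -- pre-rounding the operands discards
    // the sub-ulp contributions that together should have rounded the sum up
    std::printf("%f vs %f\n",
                static_cast<float>(sum_first),
                static_cast<float>(convert_first));
}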
@@ -57,12 +57,6 @@ struct PassThrough
         y = x;
     }
 
-    template <>
-    __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
-    {
-        y = type_convert<half_t>(x);
-    }
-
     template <>
     __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
     {
@@ -126,34 +120,6 @@ struct PassThrough
     }
 };
 
-struct AddBias
-{
-    template <typename E, typename C, typename D0>
-    __host__ __device__ void operator()(E& e, const C& c, const D0& d0) const;
-
-    template <>
-    __host__ __device__ void
-    operator()<ck::half_t, float, float>(ck::half_t& e, const float& c, const float& d0) const
-    {
-        e = c + d0;
-    }
-
-    template <>
-    __host__ __device__ void operator()<ck::half_t, ck::half_t, float>(ck::half_t& e,
-                                                                       const ck::half_t& c,
-                                                                       const float& d0) const
-    {
-        e = c + d0;
-    }
-
-    template <>
-    __host__ __device__ void
-    operator()<float, float, float>(float& e, const float& c, const float& d0) const
-    {
-        e = c + d0;
-    }
-};
-
 struct UnaryConvert
 {
     template <typename Y, typename X>
@@ -876,8 +876,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                const index_t StrideE,
                const Block2ETileMap& block_2_etile_map)
     {
-        const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
-        const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
+        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
         const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
 
         // tensor descriptors for problem definiton
@@ -902,8 +902,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
 
-        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                DsGridDesc_M_N{}))>;
 
         DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
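The grid-pointer casts above now go through separate ADataType and BDataType aliases rather than one shared ABDataType. A simplified, hypothetical sketch of what that buys (names and signature invented for illustration): one kernel template can then serve problems whose A and B operands use different element types.

// Heavily simplified stand-in for the gridwise entry point: type-erased grid
// pointers, each cast through its own data-type alias. Not CK code.
#include <cstdint>
#include <cstdio>

template <typename ADataType, typename BDataType, typename EDataType>
void run(const void* p_a_grid_, const void* p_b_grid_, void* p_e_grid_, int n)
{
    const auto* p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
    const auto* p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
    auto* p_e_grid       = reinterpret_cast<EDataType*>(p_e_grid_);

    // toy "GEMM": an elementwise product stands in for the real kernel body
    for(int i = 0; i < n; ++i)
        p_e_grid[i] =
            static_cast<EDataType>(p_a_grid[i]) * static_cast<EDataType>(p_b_grid[i]);
}

int main()
{
    float  a[4] = {1, 2, 3, 4};
    int8_t b[4] = {2, 2, 2, 2}; // mixed precision: fp32 A with int8 B
    float  e[4];
    run<float, int8_t, float>(a, b, e, 4);
    std::printf("%f\n", e[3]);
}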
@@ -101,7 +101,6 @@ using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
 using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
 using Gelu = ck::tensor_operation::element_wise::Gelu;
 using Swish = ck::tensor_operation::element_wise::Swish;
-using AddBias = ck::tensor_operation::element_wise::AddBias;
 
 template <typename Activation>
 using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
@@ -28,7 +28,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
                                                          F16,
                                                          PassThrough,
                                                          PassThrough,
-                                                         AddBias>>>& instances);
+                                                         Add>>>& instances);
 
 void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
@@ -41,7 +41,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
                                                          F16,
                                                          PassThrough,
                                                          PassThrough,
-                                                         AddBias>>>& instances);
+                                                         Add>>>& instances);
 
 // fp32_output
 void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
@@ -55,7 +55,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
                                                          F32,
                                                          PassThrough,
                                                          PassThrough,
-                                                         AddBias>>>& instances);
+                                                         Add>>>& instances);
 
 void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
@@ -68,7 +68,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
                                                          F32,
                                                          PassThrough,
                                                          PassThrough,
-                                                         AddBias>>>& instances);
+                                                         Add>>>& instances);
 
 template <typename ALayout,
           typename BLayout,
@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
                             EDataType,
                             PassThrough,
                             PassThrough,
-                            AddBias>>
+                            Add>>
 {
     using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
                                               BLayout,
@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
                                               EDataType,
                                               PassThrough,
                                               PassThrough,
-                                              AddBias>;
+                                              Add>;
 
     static auto GetInstances()
     {
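For context, a heavily simplified, self-contained sketch of the instance-factory pattern this header uses (all types here are hypothetical stand-ins, not the real CK declarations): the CDE element-wise operation is part of the template key, so moving the bias instances from AddBias to Add changes which factory specialization client code resolves against.

// Mini version of the factory pattern: the element-wise ops are template
// parameters and therefore part of the lookup key. Not CK code.
#include <memory>
#include <vector>

struct PassThrough {};
struct Add {};

template <typename AElementOp, typename BElementOp, typename CDEElementOp>
struct DeviceGemm
{
    virtual ~DeviceGemm() = default;
};

template <typename DeviceOp>
struct DeviceOperationInstanceFactory;

// Only the <PassThrough, PassThrough, Add> key is registered, matching the
// change above: requests keyed on AddBias would no longer resolve.
template <>
struct DeviceOperationInstanceFactory<DeviceGemm<PassThrough, PassThrough, Add>>
{
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceGemm<PassThrough, PassThrough, Add>>> instances;
        // the real factory fills this by calling the add_device_..._instances()
        // functions declared above, one per tuned kernel configuration
        return instances;
    }
};

int main()
{
    auto instances =
        DeviceOperationInstanceFactory<DeviceGemm<PassThrough, PassThrough, Add>>::GetInstances();
    return static_cast<int>(instances.size());
}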
@@ -31,7 +31,7 @@ using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using Add = ck::tensor_operation::element_wise::AddBias;
+using Add = ck::tensor_operation::element_wise::Add;
 
 static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -31,7 +31,7 @@ using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using Add = ck::tensor_operation::element_wise::AddBias;
+using Add = ck::tensor_operation::element_wise::Add;
 
 static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -31,7 +31,7 @@ using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using Add = ck::tensor_operation::element_wise::AddBias;
+using Add = ck::tensor_operation::element_wise::Add;
 
 static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -31,7 +31,7 @@ using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using Add = ck::tensor_operation::element_wise::AddBias;
+using Add = ck::tensor_operation::element_wise::Add;
 
 static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;