Commit 58188d46 authored by Rocking
Browse files

Fuse sigmoid after groupnorm

parent aea3b411
add_example_executable(example_groupnorm_blockwise groupnorm_blockwise.cpp) add_example_executable(example_groupnorm_sigmoid groupnorm_sigmoid.cpp)
\ No newline at end of file \ No newline at end of file
...@@ -24,7 +24,7 @@ using GammaDataType = ck::half_t; ...@@ -24,7 +24,7 @@ using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Sigmoid = ck::tensor_operation::element_wise::Sigmoid;
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
...@@ -35,7 +35,7 @@ using DeviceInstance = ...@@ -35,7 +35,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
AccDataType, AccDataType,
YDataType, YDataType,
PassThrough, Sigmoid,
Rank, Rank,
NumReduceDim, NumReduceDim,
256, // BlockSize 256, // BlockSize
...@@ -91,7 +91,7 @@ int main() ...@@ -91,7 +91,7 @@ int main()
gamma_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
PassThrough{}); Sigmoid{});
if(!device_instance.IsSupportedArgument(argument_ptr.get())) if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -111,11 +111,11 @@ int main() ...@@ -111,11 +111,11 @@ int main()
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
PassThrough>; Sigmoid>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument =
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {N, H, W, G, C}, 1e-6); ref.MakeArgument(x, gamma, beta, host_y, Sigmoid{}, {N, H, W, G, C}, 1e-6);
auto ref_invoker = ref.MakeInvoker(); auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -232,6 +232,21 @@ struct Gelu ...@@ -232,6 +232,21 @@ struct Gelu
} }
}; };
/// Element-wise sigmoid functor: y = 1 / (1 + exp(-x)).
///
/// Only float, double and ck::half_t are accepted; instantiating operator()
/// with any other T is rejected at compile time by the static_assert.
struct Sigmoid
{
    /// @param y  Output; overwritten with the sigmoid of x.
    /// @param x  Input value.
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        constexpr bool is_supported_type =
            is_same<T, float>::value || is_same<T, double>::value || is_same<T, ck::half_t>::value;
        static_assert(is_supported_type, "Data type is not supported by this operation!");

        // `auto` keeps whatever type the `exp` overload in scope yields, so the
        // final division and the assignment to y convert exactly as before.
        const auto denominator = ck::type_convert<T>(1) + exp(-x);
        y                      = 1 / denominator;
    }

    // Not read by operator() above.
    // NOTE(review): looks like a leftover from another functor — confirm
    // nothing depends on it before removing.
    int32_t divider_ = 1;
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -122,6 +122,7 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -122,6 +122,7 @@ struct ReferenceGroupnorm : public device::BaseOperator
AccDataType y = gamma * (x - mean_val) / AccDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) + ck::math::sqrt(arg.epsilon_ + var_val) +
beta; beta;
arg.acc_elementwise_op_(y, y);
arg.y_(n, h, w, g, c) = type_convert<YDataType>(y); arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment