Commit 58188d46 authored by Rocking's avatar Rocking
Browse files

Fuse sigmoid after groupnorm

parent aea3b411
add_example_executable(example_groupnorm_blockwise groupnorm_blockwise.cpp)
\ No newline at end of file
add_example_executable(example_groupnorm_sigmoid groupnorm_sigmoid.cpp)
\ No newline at end of file
......@@ -24,7 +24,7 @@ using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using AccDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Sigmoid = ck::tensor_operation::element_wise::Sigmoid;
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
......@@ -35,7 +35,7 @@ using DeviceInstance =
BetaDataType,
AccDataType,
YDataType,
PassThrough,
Sigmoid,
Rank,
NumReduceDim,
256, // BlockSize
......@@ -91,7 +91,7 @@ int main()
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
PassThrough{});
Sigmoid{});
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
......@@ -111,11 +111,11 @@ int main()
BetaDataType,
YDataType,
AccDataType,
PassThrough>;
Sigmoid>;
ReferenceInstance ref;
auto ref_argument =
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {N, H, W, G, C}, 1e-6);
ref.MakeArgument(x, gamma, beta, host_y, Sigmoid{}, {N, H, W, G, C}, 1e-6);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
......
......@@ -232,6 +232,21 @@ struct Gelu
}
};
struct Sigmoid
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
"Data type is not supported by this operation!");
y = 1 / (ck::type_convert<T>(1) + exp(-x));
};
int32_t divider_ = 1;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -122,6 +122,7 @@ struct ReferenceGroupnorm : public device::BaseOperator
AccDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.acc_elementwise_op_(y, y);
arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment