Unverified Commit 3ba004f8 authored by Adrià Arrufat, committed by GitHub

Add GELU activation layer (#2204)

* Add GELU activation layer

* fix some copy-paste leftovers

* fix comment

* use exact faster implementation

* do not use cmath constants
parent f4f8bff9
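
The commit uses the exact erf-based GELU rather than an approximation ("use exact faster implementation" above). For reference, here is a short standalone sketch, not part of this commit, contrasting that exact form with the common tanh approximation from the GELU paper; the 0.044715 constant and the sqrt(2/pi) factor belong to the approximation only.

#include <cmath>
#include <cstdio>

// exact GELU: x * Phi(x), written with erf as in the diff below
static float gelu_exact(float x)
{
    return 0.5f*x*(1.0f + std::erf(x/std::sqrt(2.0f)));
}

// tanh approximation from the GELU paper:
// 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
static float gelu_tanh_approx(float x)
{
    const float c = std::sqrt(2.0f/3.14159265358979323846f);
    return 0.5f*x*(1.0f + std::tanh(c*(x + 0.044715f*x*x*x)));
}

int main()
{
    const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for (float x : xs)
        std::printf("x=%.1f  exact=%.6f  tanh=%.6f\n", x, gelu_exact(x), gelu_tanh_approx(x));
    return 0;
}
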
@@ -1692,6 +1692,47 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
void gelu (
tensor& dest,
const tensor& src
)
{
const auto d = dest.host();
const auto s = src.host();
for (size_t i = 0; i < src.size(); ++i)
d[i] = 0.5f*s[i]*(1.0f + std::erf(s[i]/sqrt_2));
}
void gelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
)
{
const float beta = 1.0f / std::sqrt(pi) / sqrt_2;
const auto compute_gradient = [beta](float x)
{
const float cdf = 0.5f*(1.0f + std::erf(x/sqrt_2));
const float pdf = beta*std::exp(-0.5f*x*x);
return cdf + x * pdf;
};
const auto g = grad.host();
const auto s = src.host();
const auto in = gradient_input.host();
if (is_same_object(grad, gradient_input))
{
for (size_t i = 0; i < src.size(); ++i)
g[i] = in[i]*compute_gradient(s[i]);
}
else
{
for (size_t i = 0; i < src.size(); ++i)
g[i] += in[i]*compute_gradient(s[i]);
}
}
// ----------------------------------------------------------------------------------------
void resize_bilinear (
......
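
The analytic gradient used above is d/dx [x*Phi(x)] = Phi(x) + x*phi(x). A standalone sanity check, not part of this commit, compares it against a central finite difference; pi and sqrt_2 mirror the constants used in the implementation.

#include <cmath>
#include <cstdio>

int main()
{
    const float pi = 3.14159265358979323846f;
    const float sqrt_2 = std::sqrt(2.0f);
    const float beta = 1.0f / std::sqrt(pi) / sqrt_2;         // 1/sqrt(2*pi)

    const auto gelu = [&](float x) { return 0.5f*x*(1.0f + std::erf(x/sqrt_2)); };
    const auto gelu_grad = [&](float x)
    {
        const float cdf = 0.5f*(1.0f + std::erf(x/sqrt_2));   // standard normal CDF
        const float pdf = beta*std::exp(-0.5f*x*x);           // standard normal PDF
        return cdf + x*pdf;
    };

    const float x = 0.7f, h = 1e-3f;
    const float numeric = (gelu(x + h) - gelu(x - h)) / (2.0f*h);
    std::printf("analytic: %f  numeric: %f\n", gelu_grad(x), numeric);
    return 0;
}
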
@@ -352,6 +352,19 @@ namespace dlib
const tensor& gradient_input
);
// ----------------------------------------------------------------------------------------
void gelu (
tensor& dest,
const tensor& src
);
void gelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
);
// ----------------------------------------------------------------------------------------
void resize_bilinear (
......
@@ -1479,6 +1479,60 @@ namespace dlib
launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_gelu(const float* s, float* d, size_t n)
{
for (auto i : grid_stride_range(0, n))
{
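// normcdf is CUDA's built-in standard normal CDF, so this matches the CPU path's 0.5f*(1 + erf(x/sqrt(2)))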
d[i] = s[i] * normcdf(s[i]);
}
}
void gelu (
tensor& dest,
const tensor& src
)
{
launch_kernel(_cuda_gelu, max_jobs(dest.size()), src.device(), dest.device(), src.size());
}
// ----------------------------------------------------------------------------------------
__device__ float gelu_compute_gradient(float x)
{
const float beta = 1.0f / std::sqrt(pi) / sqrt_2;
const float cdf = normcdf(x);
const float pdf = beta*std::exp(-0.5f*x*x);
return cdf + x * pdf;
}
__global__ void _cuda_gelu_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
for (auto i : grid_stride_range(0, n))
out[i] = gi[i]*gelu_compute_gradient(s[i]);
}
__global__ void _cuda_gelu_gradient(float* out, const float* s, const float* gi, size_t n)
{
for (auto i : grid_stride_range(0, n))
out[i] += gi[i]*gelu_compute_gradient(s[i]);
}
void gelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
)
{
float* out = grad.device();
const float* gi = gradient_input.device();
if (out == gi)
launch_kernel(_cuda_gelu_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
else
launch_kernel(_cuda_gelu_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
......
@@ -396,6 +396,19 @@ namespace dlib
const tensor& gradient_input
);
// ----------------------------------------------------------------------------------------
void gelu (
tensor& dest,
const tensor& src
);
void gelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
);
// ----------------------------------------------------------------------------------------
void resize_bilinear (
......
@@ -965,6 +965,33 @@ namespace dlib { namespace tt
#endif
}
// ----------------------------------------------------------------------------------------
void gelu (
tensor& dest,
const tensor& src
)
{
#ifdef DLIB_USE_CUDA
cuda::gelu(dest,src);
#else
cpu::gelu(dest,src);
#endif
}
void gelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
)
{
#ifdef DLIB_USE_CUDA
cuda::gelu_gradient(grad, src, gradient_input);
#else
cpu::gelu_gradient(grad, src, gradient_input);
#endif
}
// ----------------------------------------------------------------------------------------
void resize_bilinear (
......
@@ -1518,6 +1518,40 @@ namespace dlib { namespace tt
is_same_object(grad, gradient_input)==true
!*/
// ----------------------------------------------------------------------------------------
void gelu (
tensor& dest,
const tensor& src
);
/*!
requires
- have_same_dimensions(dest, src) == true
ensures
- for all valid i:
- #dest.host()[i] == src.host()[i]/2 * (1 + erf(src.host()[i]/sqrt(2)))
- This function supports in-place operation, i.e. having
is_same_object(dest, src)==true
!*/
void gelu_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
);
/*!
requires
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
ensures
- Recalling that dest is the output of gelu(dest,src), let
f(src) == dot(gradient_input,dest).
- Then this function computes the gradient of f() with respect to src and stores
it to grad. Moreover, if is_same_object(grad,gradient_input)==true then
the output is assigned to grad, replacing its previous contents.
Otherwise the output is added to grad.
- This function supports in-place operation, i.e. having
is_same_object(grad, gradient_input)==true
!*/
// ----------------------------------------------------------------------------------------
void resize_bilinear (
......
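
A minimal usage sketch of the tt::gelu entry point documented above (not from the commit; it assumes dlib is built with this change):

#include <dlib/dnn.h>
#include <iostream>

int main()
{
    using namespace dlib;
    resizable_tensor src(1,1,2,2), dest;
    src = 1;                  // every element set to 1
    dest.copy_size(src);
    tt::gelu(dest, src);      // each element becomes gelu(1) ~= 0.8413
    std::cout << mat(dest) << std::endl;
    return 0;
}
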
@@ -3287,6 +3287,78 @@ namespace dlib
template <typename SUBNET>
using htan = add_layer<htan_, SUBNET>;
// ----------------------------------------------------------------------------------------
class gelu_
{
public:
gelu_()
{
}
template <typename SUBNET>
void setup (const SUBNET& /*sub*/)
{
}
template <typename SUBNET>
void forward(
const SUBNET& sub,
resizable_tensor& data_output
)
{
data_output.copy_size(sub.get_output());
tt::gelu(data_output, sub.get_output());
}
template <typename SUBNET>
void backward(
const tensor& gradient_input,
SUBNET& sub,
tensor&
)
{
tt::gelu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input);
}
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
inline dpoint map_output_to_input (const dpoint& p) const { return p; }
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const gelu_& /*item*/, std::ostream& out)
{
serialize("gelu_", out);
}
friend void deserialize(gelu_& /*item*/, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "gelu_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::gelu_.");
}
friend std::ostream& operator<<(std::ostream& out, const gelu_& /*item*/)
{
out << "gelu";
return out;
}
friend void to_xml(const gelu_& /*item*/, std::ostream& out)
{
out << "<gelu/>\n";
}
private:
resizable_tensor params;
};
template <typename SUBNET>
using gelu = add_layer<gelu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class softmax_
......
@@ -2310,6 +2310,44 @@ namespace dlib
template <typename SUBNET>
using htan = add_layer<htan_, SUBNET>;
// ----------------------------------------------------------------------------------------
class gelu_
{
/*!
WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
defined above. In particular, it defines a gelu layer. Therefore, it
passes its inputs through the function
f(x) = x/2 * (1 + erf(x/sqrt(2)))
where f() is applied pointwise across the input tensor.
This is the layer type introduced in the paper:
Dan Hendrycks, Kevin Gimpel. "Gaussian Error Linear Units (GELUs)".
!*/
public:
gelu_(
);
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& data_output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor&);
dpoint map_input_to_output(dpoint p) const;
dpoint map_output_to_input(dpoint p) const;
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
interface. Note that this layer doesn't have any parameters, so the tensor
returned by get_layer_params() is always empty.
!*/
};
template <typename SUBNET>
using gelu = add_layer<gelu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class softmax_
......
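
A minimal sketch of using the new gelu layer in a network definition (not from the commit; the layer sizes and the loss are illustrative only):

#include <dlib/dnn.h>

using namespace dlib;

// a toy classifier that applies gelu after a fully connected layer,
// exactly the way relu or mish would be used
using net_type = loss_multiclass_log<
                 fc<10,
                 gelu<fc<32,
                 input<matrix<float>>
                 >>>>;

A net_type defined this way trains with dnn_trainer just like networks built from the other activation layers.
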
@@ -269,6 +269,32 @@ namespace
#endif // DLIB_USE_CUDA
}
void test_gelu()
{
#ifdef DLIB_USE_CUDA
// make sure that cuda::gelu and cpu::gelu return the same results
using namespace dlib::tt;
print_spinner();
const long n = 5;
const long k = 5;
const long nr = 3;
const long nc = 3;
resizable_tensor src(n,k,nr,nc);
tt::tensor_rand rnd;
rnd.fill_uniform(src);
resizable_tensor dest1, dest2;
dest1.copy_size(src);
dest2.copy_size(src);
// initialize to different values in order to make sure the output is actually changed
dest1 = 1;
dest2 = 2;
cuda::gelu(dest1, src);
cpu::gelu(dest2, src);
DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
#endif // DLIB_USE_CUDA
}
void test_batch_normalize()
{
using namespace dlib::tt;
@@ -1916,6 +1942,12 @@ namespace
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
gelu_ l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
softmax_ l;
@@ -3967,6 +3999,7 @@ namespace
test_sigmoid();
test_mish();
test_leaky_relu();
test_gelu();
test_batch_normalize();
test_batch_normalize_conv();
test_basic_tensor_ops();
......