Unverified Commit 1b7c7a64 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add Clipped ReLU and ELU activations (#2285)



* wip: add apis for clipped_relu and elu, and layer implementation for clipped_relu

* add tensor_tools documentation

* add cpu implementations for new activations

* add elu layer

* use upperbound and lowerbound for clipped_relu

* fix clipped_relu gradient due to wrong variable naming

* fix elu_gradient due to wrong variable naming

* fix elu_gradient documentation

* add documentation

* WIP: add test_layer cases for clipped_relu and elu

For some reason that I can't see, ELU is failing...

* add clipped_relu and elu tests... cuda elu layer still does not work

* fix spacing

* add custom cuda implementation for elu_gradient (this one works)

* Revert "add custom cuda implementation for elu_gradient (this one works)"

This reverts commit 446dd803964cc6ecca598ddf6688e5d89ca0b112.

* Revert "Revert "add custom cuda implementation for elu_gradient (this one works)""

This reverts commit 0b615f50081d0d90e71d502b6767fcb6ba62f28a.

* add comment about custom elu gradient implementation

* add gradient tests, restore cudnn elu gradient

* re add custom elu gradient implementation

* update docs

* use own cuda implementation for clipped_relu and elu
Co-authored-by: default avatarDavis E. King <davis@dlib.net>
parent 0ffe9c4c
...@@ -1863,6 +1863,98 @@ namespace dlib ...@@ -1863,6 +1863,98 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
// Applies the clipped ReLU, f(x) = min(max(x, 0), ceiling), elementwise
// to src, writing the result into dest.  dest is assumed to have the same
// dimensions as src; writing through the host pointers also works when
// dest and src alias (in-place operation).
void clipped_relu (
    tensor& dest,
    const tensor& src,
    const float ceiling
)
{
    float* d = dest.host();
    const float* s = src.host();
    for (size_t i = 0; i < src.size(); ++i)
    {
        const float x = s[i];
        d[i] = x < 0 ? 0.0f : (x > ceiling ? ceiling : x);
    }
}
// Backpropagates through the clipped ReLU.  dest must hold the forward
// outputs; the gradient passes through only where the unit was neither
// clipped at 0 nor saturated at the ceiling.  When grad aliases
// gradient_input the result overwrites grad, otherwise it accumulates.
void clipped_relu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float ceiling
)
{
    float* g = grad.host();
    const float* d = dest.host();
    const float* gi = gradient_input.host();
    const bool inplace = is_same_object(grad, gradient_input);
    for (size_t i = 0; i < dest.size(); ++i)
    {
        const bool active = d[i] > 0 && d[i] < ceiling;
        if (inplace)
            g[i] = active ? gi[i] : 0;
        else if (active)
            g[i] += gi[i];
    }
}
// ----------------------------------------------------------------------------------------
// Applies the exponential linear unit, f(x) = x > 0 ? x : alpha*(exp(x)-1),
// elementwise to src, writing the results into dest.
void elu (
    tensor& dest,
    const tensor& src,
    const float alpha
)
{
    float* d = dest.host();
    const float* s = src.host();
    for (size_t i = 0; i < src.size(); ++i)
        d[i] = s[i] > 0 ? s[i] : alpha * (std::exp(s[i]) - 1.0f);
}
// Backpropagates through the ELU.  dest must contain the ELU *outputs*, so
// for the negative branch (dest <= 0) the derivative alpha*exp(x) is
// recovered as dest + alpha, since dest == alpha*(exp(x)-1).
// BUG FIX: the previous code computed alpha*exp(dest), i.e. it exponentiated
// the output instead of the input, which is only correct where x == dest
// (the positive branch, where this expression is never used anyway).
void elu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
)
{
    const auto out = grad.host();
    const auto in = dest.host();
    const auto gi = gradient_input.host();
    if (is_same_object(grad, gradient_input))
    {
        // In-place: overwrite grad with the backpropagated gradient.
        for (size_t i = 0; i < dest.size(); ++i)
        {
            if (in[i] > 0)
                out[i] = gi[i];
            else
                out[i] = (alpha + in[i]) * gi[i];
        }
    }
    else
    {
        // Out-of-place: accumulate the gradient into grad.
        for (size_t i = 0; i < dest.size(); ++i)
        {
            if (in[i] > 0)
                out[i] += gi[i];
            else
                out[i] += (alpha + in[i]) * gi[i];
        }
    }
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void gelu ( void gelu (
......
...@@ -376,6 +376,36 @@ namespace dlib ...@@ -376,6 +376,36 @@ namespace dlib
const tensor& gradient_input const tensor& gradient_input
); );
// ------------------------------------------------------------------------------------
// Applies f(x) = min(max(x, 0), ceiling) elementwise to src, storing into dest.
void clipped_relu (
    tensor& dest,
    const tensor& src,
    const float ceiling
);

// Gradient of clipped_relu().  dest must hold the forward outputs; grad is
// overwritten when it aliases gradient_input, otherwise accumulated into.
void clipped_relu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float ceiling
);

// ------------------------------------------------------------------------------------

// Applies the ELU, f(x) = x>0 ? x : alpha*(exp(x)-1), elementwise.
void elu (
    tensor& dest,
    const tensor& src,
    const float alpha
);

// Gradient of elu().  dest must hold the forward outputs; grad is
// overwritten when it aliases gradient_input, otherwise accumulated into.
void elu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
);
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void gelu ( void gelu (
......
...@@ -1480,6 +1480,127 @@ namespace dlib ...@@ -1480,6 +1480,127 @@ namespace dlib
launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size()); launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
} }
// ----------------------------------------------------------------------------------------
// Elementwise clipped ReLU kernel: clamps each value to [0, ceiling].
// BUG FIX: the previous code wrote 0 for inputs >= ceiling, which disagrees
// with the CPU implementation (upperbound(lowerbound(x,0),ceiling)) — values
// above the ceiling must saturate at the ceiling, not be zeroed.
// Also renamed the threshold parameter from "alpha" to "ceiling" for
// consistency with the rest of the API.
__global__ void _cuda_clipped_relu(const float* s, float* d, size_t n, const float ceiling)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] < 0)
            d[i] = 0.f;
        else if (s[i] > ceiling)
            d[i] = ceiling;
        else
            d[i] = s[i];
    }
}
// Launches the clipped ReLU kernel over every element of src, writing into
// dest (assumed to have the same dimensions as src).
// Renamed the parameter from "alpha" to "ceiling" to match the cpu and tt
// declarations of this function.
void clipped_relu (
    tensor& dest,
    const tensor &src,
    const float ceiling
)
{
    launch_kernel(_cuda_clipped_relu, max_jobs(dest.size()),
                  src.device(), dest.device(), src.size(), ceiling);
}
// ----------------------------------------------------------------------------------------
// In-place clipped ReLU gradient kernel: s holds the forward outputs, and
// the gradient passes through only where 0 < s[i] < ceiling (the unit was
// neither clipped at 0 nor saturated).  Overwrites out.
// Renamed the parameter from "alpha" to "ceiling" for consistency.
__global__ void _cuda_clipped_relu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float ceiling)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] > 0 && s[i] < ceiling)
            out[i] = gi[i];
        else
            out[i] = 0.f;
    }
}
// Accumulating clipped ReLU gradient kernel: same pass-through condition as
// the in-place variant, but adds into out instead of overwriting it.
// Renamed the parameter from "alpha" to "ceiling" for consistency.
__global__ void _cuda_clipped_relu_gradient(float* out, const float* s, const float* gi, size_t n, const float ceiling)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] > 0 && s[i] < ceiling)
            out[i] += gi[i];
    }
}
// Backpropagates through the clipped ReLU using the forward outputs in dest.
// Dispatches to the in-place kernel (overwrite) when grad and gradient_input
// share device memory, otherwise to the accumulating kernel.
// Renamed the parameter from "alpha" to "ceiling" to match the cpu and tt
// declarations of this function.
void clipped_relu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float ceiling
)
{
    float* out = grad.device();
    const float* gi = gradient_input.device();
    if (out == gi)
        launch_kernel(_cuda_clipped_relu_gradient_inplace, max_jobs(grad.size()), out, dest.device(), gi, grad.size(), ceiling);
    else
        launch_kernel(_cuda_clipped_relu_gradient, max_jobs(grad.size()), out, dest.device(), gi, grad.size(), ceiling);
}
// ----------------------------------------------------------------------------------------
// Elementwise ELU kernel, one grid-stride pass over the n elements:
// d[i] = s[i] for positive inputs, otherwise alpha*(exp(s[i]) - 1).
__global__ void _cuda_elu(const float* s, float* d, size_t n, const float alpha)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] > 0)
            d[i] = s[i];
        else
            d[i] = alpha * (std::exp(s[i]) - 1.0f);
    }
}
// Launches the ELU kernel over every element of src, writing into dest.
// dest is assumed to have the same dimensions as src (the kernel visits
// src.size() elements).
void elu (
    tensor& dest,
    const tensor &src,
    const float alpha
)
{
    launch_kernel(_cuda_elu, max_jobs(dest.size()), src.device(), dest.device(), src.size(), alpha);
}
// ----------------------------------------------------------------------------------------
// In-place ELU gradient kernel.  s holds the forward *outputs*, so on the
// negative branch the derivative alpha*exp(x) equals s[i] + alpha, because
// s[i] == alpha*(exp(x)-1).  Overwrites out.
// BUG FIX: the previous code computed alpha*exp(s[i]), exponentiating the
// output rather than the input, which gives the wrong gradient for x <= 0.
__global__ void _cuda_elu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] > 0)
            out[i] = gi[i];
        else
            out[i] = (alpha + s[i]) * gi[i];
    }
}
// Accumulating ELU gradient kernel: same math as the in-place variant
// (derivative on the negative branch is s[i] + alpha, since s holds the
// forward outputs), but adds into out instead of overwriting it.
// BUG FIX: replaced the incorrect alpha*exp(s[i]) with (alpha + s[i]).
__global__ void _cuda_elu_gradient(float* out, const float* s, const float* gi, size_t n, const float alpha)
{
    for (auto i : grid_stride_range(0, n))
    {
        if (s[i] > 0)
            out[i] += gi[i];
        else
            out[i] += (alpha + s[i]) * gi[i];
    }
}
// Backpropagates through the ELU using the forward outputs stored in dest.
// When grad and gradient_input share device memory the in-place kernel is
// used (overwrites grad); otherwise the accumulating kernel adds into grad.
void elu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
)
{
    float* out = grad.device();
    const float* gi = gradient_input.device();
    if (out == gi)
        launch_kernel(_cuda_elu_gradient_inplace, max_jobs(grad.size()), out, dest.device(), gi, grad.size(), alpha);
    else
        launch_kernel(_cuda_elu_gradient, max_jobs(grad.size()), out, dest.device(), gi, grad.size(), alpha);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
__global__ void _cuda_gelu(const float* s, float* d, size_t n) __global__ void _cuda_gelu(const float* s, float* d, size_t n)
......
...@@ -420,6 +420,36 @@ namespace dlib ...@@ -420,6 +420,36 @@ namespace dlib
const tensor& gradient_input const tensor& gradient_input
); );
// ----------------------------------------------------------------------------------------
// Applies f(x) = min(max(x, 0), ceiling) elementwise on the device.
// CONSISTENCY FIX: the parameter was declared as "coef" here while every
// other declaration/definition of this function calls it "ceiling".
void clipped_relu (
    tensor& dest,
    const tensor& src,
    const float ceiling
);

// Gradient of clipped_relu(); dest must hold the forward outputs.
void clipped_relu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float ceiling
);

// ------------------------------------------------------------------------------------

// Applies the ELU, f(x) = x>0 ? x : alpha*(exp(x)-1), elementwise on the device.
void elu (
    tensor& dest,
    const tensor& src,
    const float alpha
);

// Gradient of elu(); dest must hold the forward outputs.
void elu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
);
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void gelu ( void gelu (
......
...@@ -131,11 +131,11 @@ namespace dlib ...@@ -131,11 +131,11 @@ namespace dlib
cudnn_activation_descriptor( cudnn_activation_descriptor(
cudnnActivationMode_t mode, cudnnActivationMode_t mode,
cudnnNanPropagation_t reluNanOpt, cudnnNanPropagation_t reluNanOpt,
double reluCeiling double coef
) )
{ {
CHECK_CUDNN(cudnnCreateActivationDescriptor(&handle)); CHECK_CUDNN(cudnnCreateActivationDescriptor(&handle));
CHECK_CUDNN(cudnnSetActivationDescriptor(handle, mode, reluNanOpt, reluCeiling)); CHECK_CUDNN(cudnnSetActivationDescriptor(handle, mode, reluNanOpt, coef));
} }
~cudnn_activation_descriptor() ~cudnn_activation_descriptor()
...@@ -1668,6 +1668,7 @@ namespace dlib ...@@ -1668,6 +1668,7 @@ namespace dlib
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
} }
} }
......
...@@ -999,6 +999,64 @@ namespace dlib { namespace tt ...@@ -999,6 +999,64 @@ namespace dlib { namespace tt
#endif #endif
} }
// ----------------------------------------------------------------------------------------
// Forward pass of the clipped ReLU: dispatches to the CUDA implementation
// when dlib is built with CUDA support, otherwise to the CPU implementation.
void clipped_relu (
    tensor& dest,
    const tensor& src,
    const float ceiling
)
{
#ifdef DLIB_USE_CUDA
    cuda::clipped_relu(dest, src, ceiling);
#else
    cpu::clipped_relu(dest, src, ceiling);
#endif
}
// Backward pass of the clipped ReLU: dispatches to the CUDA implementation
// when available, otherwise to the CPU implementation.
void clipped_relu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float ceiling
)
{
#ifdef DLIB_USE_CUDA
    cuda::clipped_relu_gradient(grad, dest, gradient_input, ceiling);
#else
    cpu::clipped_relu_gradient(grad, dest, gradient_input, ceiling);
#endif
}
// ----------------------------------------------------------------------------------------
// Forward pass of the ELU: dispatches to the CUDA implementation when
// available, otherwise to the CPU implementation.
void elu (
    tensor& dest,
    const tensor& src,
    const float alpha
)
{
#ifdef DLIB_USE_CUDA
    cuda::elu(dest, src, alpha);
#else
    cpu::elu(dest, src, alpha);
#endif
}
// Backward pass of the ELU: dispatches to the CUDA implementation when
// available, otherwise to the CPU implementation.
void elu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
)
{
#ifdef DLIB_USE_CUDA
    cuda::elu_gradient(grad, dest, gradient_input, alpha);
#else
    cpu::elu_gradient(grad, dest, gradient_input, alpha);
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void gelu ( void gelu (
......
...@@ -1542,6 +1542,85 @@ namespace dlib { namespace tt ...@@ -1542,6 +1542,85 @@ namespace dlib { namespace tt
is_same_object(grad, gradient_input)==true is_same_object(grad, gradient_input)==true
!*/ !*/
// ----------------------------------------------------------------------------------------
void clipped_relu (
    tensor& dest,
    const tensor& src,
    const float ceiling
);
/*!
    requires
        - have_same_dimensions(dest, src) == true
    ensures
        - for all valid i:
            - #dest.host()[i] == std::min(std::max(src.host()[i], 0.f), ceiling)
        - This function supports in-place operation, i.e. having
          is_same_object(dest, src)==true
!*/

void clipped_relu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float ceiling
);
/*!
    requires
        - have_same_dimensions(dest,gradient_input) == true
        - have_same_dimensions(dest,grad) == true
    ensures
        - Recalling that dest is the output of clipped_relu(dest,SRC,ceiling) for
          some SRC tensor, let f(SRC) == dot(gradient_input,dest).  Then this
          function computes the gradient of f() with respect to SRC and stores it
          to grad.  Moreover, if is_same_object(grad,gradient_input)==true then
          the output is assigned to grad, replacing its previous contents.
          Otherwise the output is added to grad.
        - This function supports in-place operation, i.e. having
          is_same_object(grad, gradient_input)==true
!*/
// ----------------------------------------------------------------------------------------
void elu (
    tensor& dest,
    const tensor& src,
    const float alpha
);
/*!
    requires
        - have_same_dimensions(dest, src) == true
    ensures
        - for all valid i:
            - if (src.host()[i] > 0) then
                - #dest.host()[i] == src.host()[i]
            - else
                - #dest.host()[i] == alpha * (std::exp(src.host()[i]) - 1)
        - This function supports in-place operation, i.e. having
          is_same_object(dest, src)==true
!*/

void elu_gradient (
    tensor& grad,
    const tensor& dest,
    const tensor& gradient_input,
    const float alpha
);
/*!
    requires
        - have_same_dimensions(dest,gradient_input) == true
        - have_same_dimensions(dest,grad) == true
    ensures
        - Recalling that dest is the output of elu(dest,SRC,alpha) for some SRC
          tensor, let f(SRC) == dot(gradient_input,dest).  Then this function
          computes the gradient of f() with respect to SRC and stores it to
          grad.  Moreover, if is_same_object(grad,gradient_input)==true then the
          output is assigned to grad, replacing its previous contents.
          Otherwise the output is added to grad.
        - This function supports in-place operation, i.e. having
          is_same_object(grad, gradient_input)==true
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void gelu ( void gelu (
......
...@@ -3458,6 +3458,174 @@ namespace dlib ...@@ -3458,6 +3458,174 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using htan = add_layer<htan_, SUBNET>; using htan = add_layer<htan_, SUBNET>;
// ----------------------------------------------------------------------------------------
// Clipped ReLU activation layer: f(x) = min(max(x, 0), ceiling), where the
// ceiling is a fixed (non-learned) scalar set at construction time.
class clipped_relu_
{
public:
    // The default ceiling of 6 gives the commonly used "ReLU6" activation.
    clipped_relu_(
        const float ceiling_ = 6.0f
    ) : ceiling(ceiling_)
    {
    }

    // Returns the saturation value used by this layer.
    float get_ceiling(
    ) const {
        return ceiling;
    }

    template <typename SUBNET>
    void setup (const SUBNET& /*sub*/)
    {
        // Nothing to initialize: this layer has no learnable parameters.
    }

    template <typename SUBNET>
    void forward(
        SUBNET& sub,
        resizable_tensor& data_output
    )
    {
        data_output.copy_size(sub.get_output());
        tt::clipped_relu(data_output, sub.get_output(), ceiling);
    }

    template <typename SUBNET>
    void backward(
        const tensor& gradient_input,
        SUBNET& sub,
        tensor&
    )
    {
        // The gradient is computed from the layer's own output
        // (sub.get_output()), so the input doesn't need to be retained.
        tt::clipped_relu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input, ceiling);
    }

    // Elementwise layer: input and output coordinates coincide.
    inline dpoint map_input_to_output (const dpoint& p) const { return p; }
    inline dpoint map_output_to_input (const dpoint& p) const { return p; }

    // No learnable parameters, so params is always empty.
    const tensor& get_layer_params() const { return params; }
    tensor& get_layer_params() { return params; }

    friend void serialize(const clipped_relu_& item, std::ostream& out)
    {
        serialize("clipped_relu_", out);
        serialize(item.ceiling, out);
    }

    friend void deserialize(clipped_relu_& item, std::istream& in)
    {
        std::string version;
        deserialize(version, in);
        if (version != "clipped_relu_")
            throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::clipped_relu_.");
        deserialize(item.ceiling, in);
    }

    friend std::ostream& operator<<(std::ostream& out, const clipped_relu_& item)
    {
        out << "clipped_relu\t("
            << "ceiling=" << item.ceiling
            << ")";
        return out;
    }

    friend void to_xml(const clipped_relu_& item, std::ostream& out)
    {
        out << "<clipped_relu ceiling='" << item.ceiling << "'/>\n";
    }

private:
    resizable_tensor params;   // always empty; required by the layer interface
    float ceiling;             // fixed saturation value
};

template <typename SUBNET>
using clipped_relu = add_layer<clipped_relu_, SUBNET>;
// ----------------------------------------------------------------------------------------
// Exponential linear unit (ELU) layer: f(x) = x>0 ? x : alpha*(exp(x)-1),
// where alpha is a fixed (non-learned) scalar set at construction time.
class elu_
{
public:
    elu_(
        const float alpha_ = 1.0f
    ) : alpha(alpha_)
    {
    }

    // Returns the alpha value used by this layer.
    float get_alpha(
    ) const {
        return alpha;
    }

    template <typename SUBNET>
    void setup (const SUBNET& /*sub*/)
    {
        // Nothing to initialize: this layer has no learnable parameters.
    }

    template <typename SUBNET>
    void forward(
        SUBNET& sub,
        resizable_tensor& data_output
    )
    {
        data_output.copy_size(sub.get_output());
        tt::elu(data_output, sub.get_output(), alpha);
    }

    template <typename SUBNET>
    void backward(
        const tensor& gradient_input,
        SUBNET& sub,
        tensor&
    )
    {
        // The gradient is computed from the layer's own output
        // (sub.get_output()), so the input doesn't need to be retained.
        tt::elu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input, alpha);
    }

    // Elementwise layer: input and output coordinates coincide.
    inline dpoint map_input_to_output (const dpoint& p) const { return p; }
    inline dpoint map_output_to_input (const dpoint& p) const { return p; }

    // No learnable parameters, so params is always empty.
    const tensor& get_layer_params() const { return params; }
    tensor& get_layer_params() { return params; }

    friend void serialize(const elu_& item, std::ostream& out)
    {
        serialize("elu_", out);
        serialize(item.alpha, out);
    }

    friend void deserialize(elu_& item, std::istream& in)
    {
        std::string version;
        deserialize(version, in);
        if (version != "elu_")
            throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::elu_.");
        deserialize(item.alpha, in);
    }

    friend std::ostream& operator<<(std::ostream& out, const elu_& item)
    {
        out << "elu\t ("
            << "alpha=" << item.alpha
            << ")";
        return out;
    }

    friend void to_xml(const elu_& item, std::ostream& out)
    {
        out << "<elu alpha='" << item.alpha << "'/>\n";
    }

private:
    resizable_tensor params;   // always empty; required by the layer interface
    float alpha;               // fixed scale of the negative (saturating) branch
};

template <typename SUBNET>
using elu = add_layer<elu_, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class gelu_ class gelu_
......
...@@ -2394,6 +2394,9 @@ namespace dlib ...@@ -2394,6 +2394,9 @@ namespace dlib
passes its inputs through the function passes its inputs through the function
f(x)= x*tanh(log(1+exp(x))) f(x)= x*tanh(log(1+exp(x)))
where f() is applied pointwise across the input tensor. where f() is applied pointwise across the input tensor.
This is the layer type introduced in the paper:
Diganta Misra. "Mish: A Self Regularized Non-Monotonic Activation Function"
!*/ !*/
public: public:
...@@ -2453,6 +2456,103 @@ namespace dlib ...@@ -2453,6 +2456,103 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using htan = add_layer<htan_, SUBNET>; using htan = add_layer<htan_, SUBNET>;
// ----------------------------------------------------------------------------------------
class clipped_relu_
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
            defined above.  In particular, it defines a clipped version of the relu
            layer.  Therefore, it passes its inputs through the function
                f(x) = min(max(x, 0), ceiling)
            where f() is applied pointwise across the input tensor and ceiling is a
            non-learned scalar.
    !*/

public:
    clipped_relu_(
        const float ceiling = 6.0f
    );
    /*!
        ensures
            - the ceiling parameter will be initialized with the ceiling value
    !*/

    float get_ceiling() const;
    /*!
        ensures
            - returns the ceiling parameter of the clipped_relu
    !*/

    template <typename SUBNET> void setup (const SUBNET& sub);
    void forward_inplace(const tensor& input, tensor& output);
    void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
    dpoint map_input_to_output(dpoint p) const;
    dpoint map_output_to_input(dpoint p) const;
    const tensor& get_layer_params() const;
    tensor& get_layer_params();
    /*!
        These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
        interface.  Note that this layer doesn't have any parameters, so the tensor
        returned by get_layer_params() is always empty.
    !*/
};

template <typename SUBNET>
using clipped_relu = add_layer<clipped_relu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class elu_
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
            defined above.  In particular, it defines an exponential linear unit.
            Therefore, it passes its inputs through the function
                f(x) = x>0 ? x : alpha*(exp(x)-1)
            where f() is applied pointwise across the input tensor and alpha is a
            non-learned scalar.

            This is the layer type introduced in the paper:
                Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter.
                "Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)".
    !*/

public:
    elu_(
        const float alpha = 1.0f
    );
    /*!
        ensures
            - the alpha parameter will be initialized with the alpha value
    !*/

    float get_alpha() const;
    /*!
        ensures
            - returns the alpha parameter of the elu
    !*/

    template <typename SUBNET> void setup (const SUBNET& sub);
    void forward_inplace(const tensor& input, tensor& output);
    void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
    dpoint map_input_to_output(dpoint p) const;
    dpoint map_output_to_input(dpoint p) const;
    const tensor& get_layer_params() const;
    tensor& get_layer_params();
    /*!
        These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
        interface.  Note that this layer doesn't have any parameters, so the tensor
        returned by get_layer_params() is always empty.
    !*/
};

template <typename SUBNET>
using elu = add_layer<elu_, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class gelu_ class gelu_
......
...@@ -256,16 +256,93 @@ namespace ...@@ -256,16 +256,93 @@ namespace
resizable_tensor src(n, k, nr, nc); resizable_tensor src(n, k, nr, nc);
tt::tensor_rand rnd; tt::tensor_rand rnd;
rnd.fill_uniform(src); rnd.fill_uniform(src);
resizable_tensor dest1, dest2; resizable_tensor dest_cuda, dest_cpu;
dest1.copy_size(src); dest_cuda.copy_size(src);
dest2.copy_size(src); dest_cpu.copy_size(src);
// initialize to different values in order to make sure the output is actually changed // initialize to different values in order to make sure the output is actually changed
dest1 = 1; dest_cuda = 1;
dest2 = 2; dest_cpu = 2;
cuda::leaky_relu(dest1, src, alpha); cuda::leaky_relu(dest_cuda, src, alpha);
cpu::leaky_relu(dest2, src, alpha); cpu::leaky_relu(dest_cpu, src, alpha);
DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2)))); DLIB_TEST_MSG(max(abs(mat(dest_cuda) - mat(dest_cpu))) < 1e-7, max(abs(mat(dest_cuda) - mat(dest_cpu))));
#endif // DLIB_USE_CUDA
}
// Checks that the CUDA and CPU clipped ReLU implementations agree, for both
// the forward pass and the gradient.
void test_clipped_relu()
{
#ifdef DLIB_USE_CUDA
    using namespace dlib::tt;
    print_spinner();
    const float ceiling = 6.0f;
    resizable_tensor src(5, 5, 3, 3);
    tt::tensor_rand rnd;
    rnd.fill_uniform(src);

    // forward pass
    resizable_tensor out_cuda, out_cpu;
    out_cuda.copy_size(src);
    out_cpu.copy_size(src);
    // seed the outputs with different values so we can tell both were written
    out_cuda = 1;
    out_cpu = 2;
    cuda::clipped_relu(out_cuda, src, ceiling);
    cpu::clipped_relu(out_cpu, src, ceiling);
    auto err = max(abs(mat(out_cuda) - mat(out_cpu)));
    DLIB_TEST_MSG(err < 1e-7, "error: " << err);

    // gradients
    resizable_tensor g_cuda, g_cpu, g_in;
    g_cuda.copy_size(src);
    g_cpu.copy_size(src);
    g_in.copy_size(src);
    rnd.fill_uniform(g_in);
    g_cuda = 0;
    g_cpu = 0;
    cuda::clipped_relu_gradient(g_cuda, out_cuda, g_in, ceiling);
    cpu::clipped_relu_gradient(g_cpu, out_cpu, g_in, ceiling);
    err = max(abs(mat(g_cuda) - mat(g_cpu)));
    DLIB_TEST_MSG(err < 1e-7, "error: " << err);
#endif // DLIB_USE_CUDA
}
// Checks that the CUDA and CPU ELU implementations agree, for both the
// forward pass and the gradient.
void test_elu()
{
#ifdef DLIB_USE_CUDA
    using namespace dlib::tt;
    print_spinner();
    const float alpha = 1.0f;
    resizable_tensor src(5, 5, 3, 3);
    tt::tensor_rand rnd;
    rnd.fill_uniform(src);

    // forward pass
    resizable_tensor out_cuda, out_cpu;
    out_cuda.copy_size(src);
    out_cpu.copy_size(src);
    // seed the outputs with different values so we can tell both were written
    out_cuda = 1;
    out_cpu = 2;
    cuda::elu(out_cuda, src, alpha);
    cpu::elu(out_cpu, src, alpha);
    auto err = max(abs(mat(out_cuda) - mat(out_cpu)));
    DLIB_TEST_MSG(err < 1e-7, "error: " << err);

    // gradients
    resizable_tensor g_cuda, g_cpu, g_in;
    g_cuda.copy_size(src);
    g_cpu.copy_size(src);
    g_in.copy_size(src);
    rnd.fill_uniform(g_in);
    g_cuda = 0;
    g_cpu = 0;
    cuda::elu_gradient(g_cuda, out_cuda, g_in, alpha);
    cpu::elu_gradient(g_cpu, out_cpu, g_in, alpha);
    err = max(abs(mat(g_cuda) - mat(g_cpu)));
    DLIB_TEST_MSG(err < 1e-7, "error: " << err);
#endif // DLIB_USE_CUDA
}
...@@ -2002,6 +2079,18 @@ namespace ...@@ -2002,6 +2079,18 @@ namespace
auto res = test_layer(l); auto res = test_layer(l);
DLIB_TEST_MSG(res, res); DLIB_TEST_MSG(res, res);
} }
{
print_spinner();
clipped_relu_ l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
elu_ l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{ {
print_spinner(); print_spinner();
gelu_ l; gelu_ l;
...@@ -4138,6 +4227,8 @@ namespace ...@@ -4138,6 +4227,8 @@ namespace
test_sigmoid(); test_sigmoid();
test_mish(); test_mish();
test_leaky_relu(); test_leaky_relu();
test_clipped_relu();
test_elu();
test_gelu(); test_gelu();
test_batch_normalize(); test_batch_normalize();
test_batch_normalize_conv(); test_batch_normalize_conv();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment