Unverified commit d1d96e38 authored by Adrià Arrufat, committed by GitHub

remove branch from cuda kernel (#2045)

* remove branch from cuda kernel

* promote lambda to a global function
parent 57bb5eb5
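For illustration, here is a minimal, self-contained CUDA sketch of the pattern this commit applies (hypothetical names, not dlib code): the per-element aliasing branch is hoisted out of the kernel, the shared math is promoted from a local lambda to a free __device__ function, and the host picks one of two specialized kernels.

// Minimal, self-contained sketch of the pattern (hypothetical names, not dlib code).
#include <cstdio>
#include <cuda_runtime.h>

// Shared element-wise math, a free __device__ function so both kernels can reuse it
// (stand-in for the real activation gradient).
__device__ float scale_gradient(float x)
{
    return 2.0f * x;
}

// In-place variant: overwrites the output.
__global__ void apply_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
        out[i] = gi[i] * scale_gradient(s[i]);
}

// Accumulating variant: adds into the output.
__global__ void apply_gradient(float* out, const float* s, const float* gi, size_t n)
{
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
        out[i] += gi[i] * scale_gradient(s[i]);
}

int main()
{
    const size_t n = 1 << 20;
    float *out = nullptr, *s = nullptr, *gi = nullptr;
    cudaMallocManaged(&out, n * sizeof(float));
    cudaMallocManaged(&s,   n * sizeof(float));
    cudaMallocManaged(&gi,  n * sizeof(float));
    for (size_t i = 0; i < n; ++i) { out[i] = 1.f; s[i] = 0.5f; gi[i] = 1.f; }

    // The aliasing check happens once on the host instead of inside every thread.
    if (out == gi)
        apply_gradient_inplace<<<128, 256>>>(out, s, gi, n);
    else
        apply_gradient<<<128, 256>>>(out, s, gi, n);
    cudaDeviceSynchronize();

    printf("out[0] = %f\n", out[0]);  // accumulating path here: 1 + 1*2*0.5 = 2
    cudaFree(out); cudaFree(s); cudaFree(gi);
    return 0;
}

dlib's grid_stride_range helper is replaced by an explicit grid-stride loop so the sketch compiles without any dlib headers.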
@@ -1405,7 +1405,7 @@ namespace dlib
         )
         {
             float* out = grad.device();
-            const float *gi = gradient_input.device();
+            const float* gi = gradient_input.device();
             if (out == gi)
             {
                 launch_kernel(_cuda_leaky_relu_gradient_inplace, max_jobs(grad.size()),
@@ -1440,31 +1440,29 @@ namespace dlib
         // ----------------------------------------------------------------------------------------

-        __global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
+        __device__ float mish_compute_gradient(float x)
         {
-            const auto calculate_gradient = [](float x)
-            {
-                if (x >= 8)
-                    return 1.f;
-                if (x <= -8)
-                    return 0.f;
-                const auto e = std::exp(x);
-                const auto delta = 2*e + e*e + 2;
-                const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
-                return e*omega/(delta*delta);
-            };
-            if (out == gi)
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] = gi[i]*calculate_gradient(s[i]);
-            }
-            else
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] += gi[i]*calculate_gradient(s[i]);
-            }
+            if (x >= 8)
+                return 1.f;
+            if (x <= -8)
+                return 0.f;
+            const auto e = std::exp(x);
+            const auto delta = 2*e + e*e + 2;
+            const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
+            return e*omega/(delta*delta);
+        }
+
+        __global__ void _cuda_mish_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
+        {
+            for (auto i : grid_stride_range(0, n))
+                out[i] = gi[i]*mish_compute_gradient(s[i]);
+        }
+
+        __global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
+        {
+            for (auto i : grid_stride_range(0, n))
+                out[i] += gi[i]*mish_compute_gradient(s[i]);
         }

         void mish_gradient (
@@ -1473,7 +1471,12 @@ namespace dlib
             const tensor& gradient_input
         )
         {
-            launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), grad.device(), src.device(), gradient_input.device(), grad.size());
+            float* out = grad.device();
+            const float* gi = gradient_input.device();
+            if (out == gi)
+                launch_kernel(_cuda_mish_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
+            else
+                launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
         }

         // ----------------------------------------------------------------------------------------
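For reference, mish_compute_gradient evaluates the closed-form derivative of mish(x) = x·tanh(softplus(x)):

\[
\frac{d}{dx}\,\operatorname{mish}(x) = \frac{e^{x}\,\omega}{\delta^{2}},
\qquad
\omega = 4(x+1) + 4e^{2x} + e^{3x} + e^{x}(4x+6),
\qquad
\delta = 2e^{x} + e^{2x} + 2 .
\]

The early returns clamp the result to 1 for x >= 8 and 0 for x <= -8, where the exact derivative is already within a few thousandths of those values.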