Unverified Commit a1f15837 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Do not use sqrt_2 in device code (fixes #2208) (#2210)

* do not use sqrt_2 in device code

* use CUDART_SQRT_2PI

* better sort includes
parent 3ba004f8
...@@ -1711,7 +1711,7 @@ namespace dlib ...@@ -1711,7 +1711,7 @@ namespace dlib
const tensor& gradient_input const tensor& gradient_input
) )
{ {
const float beta = 1.0f / std::sqrt(pi) / sqrt_2; const float beta = 1.0f / std::sqrt(2.0f * pi);
const auto compute_gradient = [beta](float x) const auto compute_gradient = [beta](float x)
{ {
const float cdf = 0.5f*(1.0f + std::erf(x/sqrt_2)); const float cdf = 0.5f*(1.0f + std::erf(x/sqrt_2));
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cuda_dlib.h" #include "cuda_dlib.h"
#include "cudnn_dlibapi.h" #include "cudnn_dlibapi.h"
#include <math_constants.h>
namespace dlib namespace dlib
...@@ -1501,7 +1502,7 @@ namespace dlib ...@@ -1501,7 +1502,7 @@ namespace dlib
__device__ float gelu_compute_gradient(float x) __device__ float gelu_compute_gradient(float x)
{ {
const float beta = 1.0f / std::sqrt(pi) / sqrt_2; const float beta = 1.0f / CUDART_SQRT_2PI;
const float cdf = normcdf(x); const float cdf = normcdf(x);
const float pdf = beta*std::exp(-0.5f*x*x); const float pdf = beta*std::exp(-0.5f*x*x);
return cdf + x * pdf; return cdf + x * pdf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment