OpenDAS / dlib · Commits · d1d96e38

Unverified commit d1d96e38, authored Apr 01, 2020 by Adrià Arrufat, committed by GitHub Mar 31, 2020

remove branch from cuda kernel (#2045)

* remove branch from cuda kernel
* promote lambda to a global function
parent 57bb5eb5
Showing 1 changed file with 27 additions and 24 deletions

dlib/cuda/cuda_dlib.cu (+27, -24)
...
@@ -1405,7 +1405,7 @@ namespace dlib
         )
         {
             float* out = grad.device();
             const float* gi = gradient_input.device();
             if (out == gi)
             {
                 launch_kernel(_cuda_leaky_relu_gradient_inplace, max_jobs(grad.size()),
...
@@ -1440,31 +1440,29 @@ namespace dlib
     // ----------------------------------------------------------------------------------------
 
-        __global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
-        {
-            const auto calculate_gradient = [](float x)
-            {
-                if (x >= 8)
-                    return 1.f;
-                if (x <= -8)
-                    return 0.f;
-
-                const auto e = std::exp(x);
-                const auto delta = 2*e + e*e + 2;
-                const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
-                return e*omega/(delta*delta);
-            };
-
-            if (out == gi)
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] = gi[i]*calculate_gradient(s[i]);
-            }
-            else
-            {
-                for (auto i : grid_stride_range(0, n))
-                    out[i] += gi[i]*calculate_gradient(s[i]);
-            }
-        }
+        __device__ float mish_compute_gradient(float x)
+        {
+            if (x >= 8)
+                return 1.f;
+            if (x <= -8)
+                return 0.f;
+
+            const auto e = std::exp(x);
+            const auto delta = 2*e + e*e + 2;
+            const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
+            return e*omega/(delta*delta);
+        }
+
+        __global__ void _cuda_mish_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
+        {
+            for (auto i : grid_stride_range(0, n))
+                out[i] = gi[i]*mish_compute_gradient(s[i]);
+        }
+
+        __global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
+        {
+            for (auto i : grid_stride_range(0, n))
+                out[i] += gi[i]*mish_compute_gradient(s[i]);
+        }
 
         void mish_gradient (
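For reference, the new kernel layout in this hunk can be read as the minimal sketch below. It is only an illustration of the pattern, not dlib's code: the example_* names are hypothetical, a plain grid-stride loop stands in for dlib's grid_stride_range helper, and expf stands in for std::exp.

// sketch_kernels.cu -- illustration only; example_* names are hypothetical
#include <cstddef>

// The gradient formula, previously a lambda inside the kernel body, becomes a
// free __device__ function that any kernel can call.
__device__ float example_mish_compute_gradient(float x)
{
    if (x >= 8)  return 1.0f;
    if (x <= -8) return 0.0f;
    const float e     = expf(x);
    const float delta = 2*e + e*e + 2;
    const float omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
    return e*omega/(delta*delta);
}

// One kernel overwrites the output (the in-place case) ...
__global__ void example_mish_gradient_inplace(float* out, const float* s, const float* gi, size_t n)
{
    for (size_t i = blockIdx.x*(size_t)blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
        out[i] = gi[i]*example_mish_compute_gradient(s[i]);
}

// ... and the other accumulates into it, so neither kernel carries the
// out == gi branch that every thread previously had to evaluate.
__global__ void example_mish_gradient(float* out, const float* s, const float* gi, size_t n)
{
    for (size_t i = blockIdx.x*(size_t)blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
        out[i] += gi[i]*example_mish_compute_gradient(s[i]);
}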
...
@@ -1473,7 +1471,12 @@ namespace dlib
             const tensor& gradient_input
         )
         {
-            launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), grad.device(), src.device(), gradient_input.device(), grad.size());
+            float* out = grad.device();
+            const float* gi = gradient_input.device();
+            if (out == gi)
+                launch_kernel(_cuda_mish_gradient_inplace, max_jobs(grad.size()), out, src.device(), gi, grad.size());
+            else
+                launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
         }
 
     // ----------------------------------------------------------------------------------------
...
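Continuing the sketch above, a host-side wrapper can then choose the kernel once per launch instead of branching in every GPU thread, which is what the final hunk does through dlib's launch_kernel helper. The raw <<<...>>> launches, the buffer setup, and main below are illustrative assumptions rather than dlib's API; the two example_* kernels are the ones defined in the earlier sketch.

// sketch_dispatch.cu -- continues the sketch above; illustration only
#include <cstdio>
#include <cuda_runtime.h>

// The aliasing test depends only on the launch arguments, so it is evaluated
// once on the host rather than by every thread of the kernel.
void example_mish_gradient_dispatch(float* out, const float* s, const float* gi, size_t n)
{
    const int threads = 256;
    const int blocks  = (int)((n + threads - 1)/threads);
    if (out == gi)
        example_mish_gradient_inplace<<<blocks, threads>>>(out, s, gi, n);  // overwrite
    else
        example_mish_gradient<<<blocks, threads>>>(out, s, gi, n);          // accumulate
}

int main()
{
    const size_t n = 1024;
    float *grad = nullptr, *src = nullptr, *gradient_input = nullptr;
    cudaMalloc(&grad, n*sizeof(float));
    cudaMalloc(&src, n*sizeof(float));
    cudaMalloc(&gradient_input, n*sizeof(float));
    cudaMemset(grad, 0, n*sizeof(float));
    cudaMemset(src, 0, n*sizeof(float));
    cudaMemset(gradient_input, 0, n*sizeof(float));

    example_mish_gradient_dispatch(grad, src, gradient_input, n);  // distinct buffers: accumulate path
    example_mish_gradient_dispatch(grad, src, grad, n);            // out aliases gi: in-place path

    cudaDeviceSynchronize();
    std::printf("done\n");

    cudaFree(grad);
    cudaFree(src);
    cudaFree(gradient_input);
    return 0;
}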