dlib · Commit edff12d2

Adding Mish activation function

Authored Dec 02, 2019 by thebhatman; committed by Davis King on Jan 14, 2020.
Parent: cd5f0b05
Showing 9 changed files, with 337 additions and 1 deletion (+337 −1):

    dlib/cuda/cpu_dlib.cpp      +52  −0
    dlib/cuda/cpu_dlib.h        +13  −0
    dlib/cuda/cuda_dlib.cu      +57  −0
    dlib/cuda/cuda_dlib.h       +12  −0
    dlib/cuda/tensor_tools.cpp  +27  −0
    dlib/cuda/tensor_tools.h    +34  −0
    dlib/dnn/layers.h           +73  −0
    dlib/dnn/layers_abstract.h  +35  −0
    dlib/test/dnn.cpp           +34  −1
dlib/cuda/cpu_dlib.cpp
@@ -1467,6 +1467,58 @@ namespace dlib

            }
        }

    // ------------------------------------------------------------------------------------

        void mish (
            tensor& dest,
            const tensor& src
        )
        {
            const auto d = dest.host_write_only();
            const auto s = src.host();
            for (size_t i = 0; i < src.size(); ++i)
            {
                const auto e = std::exp(s[i]);
                const auto delta = 2*e + e*e + 2;
                d[i] = s[i] - 2*s[i]/delta;
            }
        }

        void mish_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
        )
        {
            const auto g = grad.host();
            const auto s = src.host();
            const auto in = gradient_input.host();
            const auto calculate_gradient = [](float x)
            {
                if (x >= 8)
                    return 1.f;
                if (x <= -8)
                    return 0.f;

                const auto e = std::exp(x);
                const auto delta = 2*e + e*e + 2;
                const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
                return e*omega/(delta*delta);
            };
            if (is_same_object(gradient_input, grad))
            {
                for (size_t i = 0; i < src.size(); ++i)
                    g[i] = in[i]*calculate_gradient(s[i]);
            }
            else
            {
                for (size_t i = 0; i < src.size(); ++i)
                    g[i] += in[i]*calculate_gradient(s[i]);
            }
        }

    // ------------------------------------------------------------------------------------

        void relu (
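Editor's note, added for clarity (not part of the commit): the forward pass avoids calling tanh and log directly. Writing $e = e^x$ and $\delta = e^2 + 2e + 2$ (the `delta` in the code above),

$$\tanh\big(\log(1+e^x)\big) = \frac{(1+e)^2 - 1}{(1+e)^2 + 1} = \frac{e^2 + 2e}{e^2 + 2e + 2} = 1 - \frac{2}{\delta},$$

so

$$\mathrm{mish}(x) = x\,\tanh\big(\log(1+e^x)\big) = x - \frac{2x}{\delta},$$

which is exactly the `d[i] = s[i] - 2*s[i]/delta` line. Differentiating the same expression gives the closed form used by `calculate_gradient`,

$$\mathrm{mish}'(x) = \frac{e\,\omega}{\delta^2}, \qquad \omega = 4(x+1) + 4e^2 + e^3 + e(4x+6),$$

and the clamps at $x \ge 8$ and $x \le -8$ short-circuit the saturated tails, where the true derivative is already within a few $10^{-3}$ of its limits 1 and 0.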
dlib/cuda/cpu_dlib.h
@@ -281,6 +281,19 @@ namespace dlib

            const tensor& gradient_input
        );

    // ------------------------------------------------------------------------------------

        void mish (
            tensor& dest,
            const tensor& src
        );

        void mish_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
        );

    // ------------------------------------------------------------------------------------

        void relu (
dlib/cuda/cuda_dlib.cu
@@ -1351,6 +1351,63 @@ namespace dlib

                          param.device(), params_grad.device());
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_mish(const float* s, float* d, size_t n)
        {
            for (auto i : grid_stride_range(0, n))
            {
                const auto e = std::exp(s[i]);
                const auto delta = 2*e + e*e + 2;
                d[i] = s[i] - 2*s[i]/delta;
            }
        }

        void mish (
            tensor& dest,
            const tensor& src
        )
        {
            launch_kernel(_cuda_mish, max_jobs(dest.size()), src.device(), dest.device(), src.size());
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_mish_gradient(float* out, const float* s, const float* gi, size_t n)
        {
            const auto calculate_gradient = [](float x)
            {
                if (x >= 8)
                    return 1.f;
                if (x <= -8)
                    return 0.f;

                const auto e = std::exp(x);
                const auto delta = 2*e + e*e + 2;
                const auto omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
                return e*omega/(delta*delta);
            };
            if (out == gi)
            {
                for (auto i : grid_stride_range(0, n))
                    out[i] = gi[i]*calculate_gradient(s[i]);
            }
            else
            {
                for (auto i : grid_stride_range(0, n))
                    out[i] += gi[i]*calculate_gradient(s[i]);
            }
        }

        void mish_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
        )
        {
            launch_kernel(_cuda_mish_gradient, max_jobs(grad.size()), grad.device(), src.device(), gradient_input.device(), grad.size());
        }

    // ----------------------------------------------------------------------------------------
    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d, ...
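Editor's note: dlib's grid_stride_range and launch_kernel helpers implement the standard CUDA grid-stride-loop pattern, so each kernel handles arbitrary n regardless of the grid size max_jobs picks. For readers who don't know those helpers, a minimal plain-CUDA equivalent of the forward kernel might look like this (a sketch; the kernel name and launch configuration are illustrative, not dlib's):

    #include <cuda_runtime.h>

    // plain-CUDA version of _cuda_mish: each thread strides across the array,
    // so any grid size covers all n elements
    __global__ void mish_kernel(const float* s, float* d, size_t n)
    {
        for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n;
             i += (size_t)gridDim.x*blockDim.x)
        {
            const float e = expf(s[i]);
            const float delta = 2.0f*e + e*e + 2.0f;
            d[i] = s[i] - 2.0f*s[i]/delta;
        }
    }

    // launch example: mish_kernel<<<(n + 127)/128, 128>>>(src_dev, dest_dev, n);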
dlib/cuda/cuda_dlib.h
@@ -367,6 +367,18 @@ namespace dlib

            tensor& params_grad
        );

    // ----------------------------------------------------------------------------------------

        void mish (
            tensor& dest,
            const tensor& src
        );

        void mish_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
        );

    // ----------------------------------------------------------------------------------------
dlib/cuda/tensor_tools.cpp
@@ -826,6 +826,33 @@ namespace dlib { namespace tt

    #endif
        }

    // ----------------------------------------------------------------------------------------

        void mish (
            tensor& dest,
            const tensor& src
        )
        {
    #ifdef DLIB_USE_CUDA
            cuda::mish(dest, src);
    #else
            cpu::mish(dest, src);
    #endif
        }

        void mish_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
        )
        {
    #ifdef DLIB_USE_CUDA
            cuda::mish_gradient(grad, src, gradient_input);
    #else
            cpu::mish_gradient(grad, src, gradient_input);
    #endif
        }

    // ----------------------------------------------------------------------------------------

        void relu (
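The tt:: wrappers above are the dispatch points the rest of dlib calls; whether the CUDA or CPU path runs is fixed at compile time by DLIB_USE_CUDA. A small usage sketch (editor's example, not from the commit; assumes dlib is built and linked):

    #include <dlib/dnn.h>
    #include <iostream>

    int main()
    {
        using namespace dlib;
        resizable_tensor x(1, 1, 1, 2), y;
        x.host()[0] = 0.0f;
        x.host()[1] = 1.0f;
        y.copy_size(x);
        tt::mish(y, x);  // runs cuda::mish or cpu::mish depending on the build
        // expected: mish(0) = 0, mish(1) = tanh(log(1+e)) ≈ 0.8651
        std::cout << y.host()[0] << " " << y.host()[1] << "\n";
    }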
dlib/cuda/tensor_tools.h
@@ -1330,6 +1330,40 @@ namespace dlib { namespace tt

                  is_same_object(grad, gradient_input)==true
        !*/

    // ----------------------------------------------------------------------------------------

        void mish (
            tensor& dest,
            const tensor& src
        );
        /*!
            requires
                - have_same_dimensions(dest, src) == true
            ensures
                - for all valid i:
                    - #dest.host()[i] == src.host()[i]*std::tanh(std::log(1+std::exp(src.host()[i])))
                - This function supports in-place operation, i.e. having
                  is_same_object(dest, src)==true
        !*/

        void mish_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
        );
        /*!
            requires
                - have_same_dimensions(src, gradient_input) == true
                - have_same_dimensions(src, grad) == true
            ensures
                - This function computes the gradient of f() with respect to src and stores
                  it to grad.  Moreover, if is_same_object(grad, gradient_input)==true then
                  the output is assigned to grad, replacing its previous contents.
                  Otherwise the output is added to grad.
                - This function supports in-place operation, i.e. having
                  is_same_object(grad, gradient_input)==true
        !*/

    // ----------------------------------------------------------------------------------------

        void relu (
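As a sanity check on the documented contract, here is a small standalone program (editor's sketch, not part of the commit) comparing the closed-form derivative used in cpu_dlib.cpp against a central difference of the documented f(x) = x*tanh(log(1+exp(x))):

    #include <cassert>
    #include <cmath>

    // f(x) = x*tanh(log(1+exp(x))), as documented above
    double mish(double x) { return x*std::tanh(std::log(1 + std::exp(x))); }

    // the closed form from calculate_gradient in cpu_dlib.cpp (without the |x|>=8 clamps)
    double mish_grad(double x)
    {
        const double e = std::exp(x);
        const double delta = 2*e + e*e + 2;
        const double omega = 4*(x + 1) + 4*e*e + e*e*e + e*(4*x + 6);
        return e*omega/(delta*delta);
    }

    int main()
    {
        for (double x = -6; x <= 6; x += 0.25)
        {
            const double h = 1e-5;
            const double numeric = (mish(x + h) - mish(x - h))/(2*h);
            assert(std::abs(numeric - mish_grad(x)) < 1e-7);  // closed form matches finite differences
        }
    }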
dlib/dnn/layers.h
@@ -2829,6 +2829,79 @@ namespace dlib

        template <typename SUBNET>
        using sig = add_layer<sig_, SUBNET>;

    // ----------------------------------------------------------------------------------------

        class mish_
        {
        public:
            mish_()
            {
            }

            template <typename SUBNET>
            void setup (const SUBNET& /*sub*/)
            {
            }

            template <typename SUBNET>
            void forward(
                const SUBNET& sub,
                resizable_tensor& data_output
            )
            {
                data_output.copy_size(sub.get_output());
                tt::mish(data_output, sub.get_output());
            }

            template <typename SUBNET>
            void backward(
                const tensor& gradient_input,
                SUBNET& sub,
                tensor&
            )
            {
                tt::mish_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input);
            }

            inline dpoint map_input_to_output (const dpoint& p) const { return p; }
            inline dpoint map_output_to_input (const dpoint& p) const { return p; }

            const tensor& get_layer_params() const { return params; }
            tensor& get_layer_params() { return params; }

            friend void serialize(const mish_&, std::ostream& out)
            {
                serialize("mish_", out);
            }

            friend void deserialize(mish_&, std::istream& in)
            {
                std::string version;
                deserialize(version, in);
                if (version != "mish_")
                    throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::mish_.");
            }

            friend std::ostream& operator<<(std::ostream& out, const mish_&)
            {
                out << "mish";
                return out;
            }

            friend void to_xml(const mish_& /*item*/, std::ostream& out)
            {
                out << "<mish/>\n";
            }

        private:
            resizable_tensor params;
        };

        template <typename SUBNET>
        using mish = add_layer<mish_, SUBNET>;

    // ----------------------------------------------------------------------------------------

        class htan_
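Editor's sketch (not part of the commit): with the add_layer alias in place, mish composes like the other activation layers such as relu and sig. The toy architecture below is made up purely for illustration:

    #include <dlib/dnn.h>
    using namespace dlib;

    // a small multilayer perceptron using mish between the fully connected layers
    using net_type = loss_multiclass_log<
                         fc<10,
                         mish<fc<84,
                         mish<fc<120,
                         input<matrix<float>>
                         >>>>>>;

Since mish_ has no parameters, it serializes as just the "mish_" version tag, and operator<< prints it as "mish", per the definitions above.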
dlib/dnn/layers_abstract.h
@@ -2125,6 +2125,41 @@ namespace dlib

        template <typename SUBNET>
        using sig = add_layer<sig_, SUBNET>;

    // ----------------------------------------------------------------------------------------

        class mish_
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
                    defined above.  In particular, it defines a mish layer.  Therefore, it
                    passes its inputs through the function
                        f(x) = x*tanh(log(1+exp(x)))
                    where f() is applied pointwise across the input tensor.
            !*/

        public:

            mish_(
            );

            template <typename SUBNET> void setup (const SUBNET& sub);
            template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& data_output);
            template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor&);
            dpoint map_input_to_output(dpoint p) const;
            dpoint map_output_to_input(dpoint p) const;
            const tensor& get_layer_params() const;
            tensor& get_layer_params();
            /*!
                These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
                interface.  Note that this layer doesn't have any parameters, so the tensor
                returned by get_layer_params() is always empty.
            !*/
        };

        template <typename SUBNET>
        using mish = add_layer<mish_, SUBNET>;

    // ----------------------------------------------------------------------------------------

        class htan_
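Editor's note, for orientation only (these properties follow from the definition of f above, not from anything else in the commit): f is smooth and non-monotonic, with

$$f(0) = 0, \qquad f(x) \approx x \ \text{for large } x, \qquad f(x) \to 0 \ \text{as } x \to -\infty,$$

and a single global minimum of roughly $-0.31$ near $x \approx -1.19$, which is what distinguishes it from relu's hard cutoff at zero.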
dlib/test/dnn.cpp
@@ -172,7 +172,7 @@ namespace

        print_spinner();
        const long nr = 3;
        const long nc = 3;
-       resizable_tensor src(5,5,nr,nr), dest(5,5,nr,nc), gradient_input(5,5,nr,nc);
+       resizable_tensor src(5,5,nr,nc), dest(5,5,nr,nc), gradient_input(5,5,nr,nc);
        tt::tensor_rand rnd;
        rnd.fill_uniform(src);
        rnd.fill_uniform(dest);

@@ -217,6 +217,32 @@ namespace

    #endif
        }

        void test_mish()
        {
    #ifdef DLIB_USE_CUDA
            // make sure that cuda::mish and cpu::mish return the same results
            using namespace dlib::tt;
            print_spinner();
            const long n = 5;
            const long k = 5;
            const long nr = 3;
            const long nc = 3;
            resizable_tensor src(n,k,nr,nc);
            tt::tensor_rand rnd;
            rnd.fill_uniform(src);

            resizable_tensor dest1, dest2;
            dest1.copy_size(src);
            dest2.copy_size(src);
            // initialize to different values in order to make sure the output is actually changed
            dest1 = 1;
            dest2 = 2;

            cuda::mish(dest1, src);
            cpu::mish(dest2, src);
            DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
    #endif // DLIB_USE_CUDA
        }

        void test_batch_normalize()
        {
            using namespace dlib::tt;

@@ -1832,6 +1858,12 @@ namespace

            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            mish_ l;
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            htan_ l;

@@ -3382,6 +3414,7 @@ namespace

        test_softmax();
        test_softmax_all();
        test_sigmoid();
        test_mish();
        test_batch_normalize();
        test_batch_normalize_conv();
        test_basic_tensor_ops();