OpenDAS / dlib / Commits

Commit 8fa65eb7
Add SmeLU activation

Authored Apr 07, 2022 by Adrià Arrufat
Committed by Davis E. King, Apr 12, 2022
Parent: 66f9b2b5

Showing 9 changed files with 406 additions and 9 deletions (+406 / -9)
dlib/cuda/cpu_dlib.cpp        +55  -0
dlib/cuda/cpu_dlib.h          +15  -0
dlib/cuda/cuda_dlib.cu        +74  -1
dlib/cuda/cuda_dlib.h         +15  -0
dlib/cuda/tensor_tools.cpp    +30  -0
dlib/cuda/tensor_tools.h      +43  -0
dlib/dnn/layers.h             +80  -0
dlib/dnn/layers_abstract.h    +53  -0
dlib/test/dnn.cpp             +41  -8
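
For orientation before the per-file diffs: the commit wires the SmeLU activation through dlib's usual three layers of plumbing (CPU/CUDA kernels, the tt:: dispatch wrappers, and a smelu_ computational layer plus its smelu<> template alias). A minimal usage sketch follows; the toy network shape is hypothetical and not part of this commit, only the smelu<> template is, and beta defaults to 1 in the layer's constructor:

    #include <dlib/dnn.h>
    using namespace dlib;

    // Hypothetical example network: two fully connected layers, each followed by the
    // new SmeLU activation, used exactly the way relu<> or gelu<> would normally be used.
    using net_type = loss_multiclass_log<
                         fc<10,
                         smelu<fc<84,
                         smelu<fc<120,
                         input<matrix<float>>
                         >>>>>>;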
dlib/cuda/cpu_dlib.cpp

@@ -1998,6 +1998,61 @@ namespace dlib
            }
        }
        void smelu (
            tensor& dest,
            const tensor& src,
            const float beta
        )
        {
            const float* s = src.host();
            float* d = dest.host();
            for (size_t i = 0; i < dest.size(); ++i)
            {
                if (s[i] >= beta)
                    d[i] = s[i];
                else if (s[i] <= -beta)
                    d[i] = 0;
                else
                    d[i] = (s[i] + beta) * (s[i] + beta) / (4 * beta);
            }
        }
        void smelu_gradient (
            tensor& grad,
            const tensor& dest,
            const tensor& gradient_input,
            const float beta
        )
        {
            const float* gi = gradient_input.host();
            const float* in = dest.host();
            float* out = grad.host();
            if (is_same_object(grad, gradient_input))
            {
                for (size_t i = 0; i < dest.size(); ++i)
                {
                    if (in[i] >= beta)
                        out[i] = gi[i];
                    else if (in[i] == 0)
                        out[i] = 0;
                    else
                        out[i] = std::sqrt(beta * in[i]) / beta * gi[i];
                }
            }
            else
            {
                for (size_t i = 0; i < dest.size(); ++i)
                {
                    if (in[i] >= beta)
                        out[i] += gi[i];
                    else if (in[i] == 0)
                        continue;
                    else
                        out[i] += std::sqrt(beta * in[i]) / beta * gi[i];
                }
            }
        }
        // ----------------------------------------------------------------------------------------

        void resize_bilinear (
        ...
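
The quadratic branch above is what makes the activation smooth: at x = beta it evaluates to (2*beta)^2 / (4*beta) = beta, matching the identity branch, and at x = -beta it evaluates to 0, matching the clamped branch. A small standalone spot check of that continuity (plain scalars, illustrative only, not part of the commit):

    #include <cassert>
    #include <cmath>

    // Scalar reference for the SmeLU forward pass shown in cpu::smelu above.
    float smelu_ref(float x, float beta)
    {
        if (x >= beta)  return x;
        if (x <= -beta) return 0.0f;
        return (x + beta) * (x + beta) / (4 * beta);
    }

    int main()
    {
        const float beta = 1.0f;
        // The three branches meet continuously at +/-beta.
        assert(std::abs(smelu_ref(beta, beta) - beta) < 1e-6f);   // (2b)^2/(4b) == b
        assert(std::abs(smelu_ref(-beta, beta) - 0.0f) < 1e-6f);  // clamped side is exactly 0
        assert(std::abs(smelu_ref(0.0f, beta) - 0.25f) < 1e-6f);  // midpoint value is b/4
    }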
dlib/cuda/cpu_dlib.h

@@ -421,6 +421,21 @@ namespace dlib

    // ----------------------------------------------------------------------------------------
        void smelu (
            tensor& dest,
            const tensor& src,
            const float beta
        );

        void smelu_gradient (
            tensor& grad,
            const tensor& dest,
            const tensor& gradient_input,
            const float beta
        );
    // ------------------------------------------------------------------------------------

        void resize_bilinear (
            tensor& dest,
            long long dest_row_stride,
        ...
dlib/cuda/cuda_dlib.cu

@@ -1366,7 +1366,7 @@ namespace dlib
        void leaky_relu (
            tensor& dest,
            const tensor& src,
            const float alpha
        )
        {
        ...

@@ -1657,6 +1657,79 @@ namespace dlib
            launch_kernel(_cuda_gelu_gradient, max_jobs(grad.size()), out, src.device(), gi, grad.size());
        }

        // ----------------------------------------------------------------------------------------
        __global__ void _cuda_smelu(const float* s, float* d, size_t n, const float beta)
        {
            for (auto i : grid_stride_range(0, n))
            {
                if (s[i] >= beta)
                    d[i] = s[i];
                else if (s[i] <= -beta)
                    d[i] = 0;
                else
                    d[i] = (s[i] + beta) * (s[i] + beta) / (4 * beta);
            }
        }

        void smelu (
            tensor& dest,
            const tensor& src,
            const float beta
        )
        {
            launch_kernel(_cuda_smelu, max_jobs(dest.size()), src.device(), dest.device(), src.size(), beta);
        }
        // ----------------------------------------------------------------------------------------

        __global__ void _cuda_smelu_gradient_inplace(float* out, const float* s, const float* gi, size_t n, const float beta)
        {
            for (auto i : grid_stride_range(0, n))
            {
                if (s[i] >= beta)
                    out[i] = gi[i];
                else if (s[i] == 0)
                    out[i] = 0;
                else
                    out[i] = std::sqrt(beta * s[i]) / beta * gi[i];
            }
        }

        __global__ void _cuda_smelu_gradient(float* out, const float* s, const float* gi, size_t n, const float beta)
        {
            for (auto i : grid_stride_range(0, n))
            {
                if (s[i] >= beta)
                    out[i] += gi[i];
                else if (s[i] == 0)
                    continue;
                else
                    out[i] += std::sqrt(beta * s[i]) / beta * gi[i];
            }
        }
        void smelu_gradient (
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input,
            const float beta
        )
        {
            float* out = grad.device();
            const float* gi = gradient_input.device();
            if (out == gi)
            {
                launch_kernel(_cuda_smelu_gradient_inplace, max_jobs(grad.size()),
                              out, src.device(), gi, grad.size(), beta);
            }
            else
            {
                launch_kernel(_cuda_smelu_gradient, max_jobs(grad.size()),
                              out, src.device(), gi, grad.size(), beta);
            }
        }
        // ----------------------------------------------------------------------------------------

        __global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
        ...
dlib/cuda/cuda_dlib.h

@@ -465,6 +465,21 @@ namespace dlib

    // ----------------------------------------------------------------------------------------
        void smelu (
            tensor& dest,
            const tensor& src,
            const float beta
        );

        void smelu_gradient (
            tensor& grad,
            const tensor& dest,
            const tensor& gradient_input,
            const float beta
        );
    // ------------------------------------------------------------------------------------

        void resize_bilinear (
            tensor& dest,
            long long dest_row_stride,
        ...
dlib/cuda/tensor_tools.cpp

@@ -1084,6 +1084,36 @@ namespace dlib { namespace tt
#endif
    }

    // ----------------------------------------------------------------------------------------
    void smelu (
        tensor& dest,
        const tensor& src,
        const float beta
    )
    {
        DLIB_CASSERT(beta > 0);
#ifdef DLIB_USE_CUDA
        cuda::smelu(dest, src, beta);
#else
        cpu::smelu(dest, src, beta);
#endif
    }

    void smelu_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input,
        const float beta
    )
    {
        DLIB_CASSERT(beta > 0);
#ifdef DLIB_USE_CUDA
        cuda::smelu_gradient(grad, dest, gradient_input, beta);
#else
        cpu::smelu_gradient(grad, dest, gradient_input, beta);
#endif
    }
    // ----------------------------------------------------------------------------------------

    void resize_bilinear (
    ...
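
These tt:: wrappers are the entry point the layer code uses: they assert beta > 0 and dispatch to the CUDA or CPU implementation depending on how dlib was built. A hedged usage sketch of the forward wrapper on a free-standing tensor (illustrative, not from the commit; the tensor shape is arbitrary):

    #include <dlib/dnn.h>

    void apply_smelu_example()
    {
        using namespace dlib;
        resizable_tensor src(4, 5, 3, 3);   // n, k, nr, nc
        tt::tensor_rand rnd;
        rnd.fill_gaussian(src);

        resizable_tensor dest;
        dest.copy_size(src);
        // Dispatches to cuda::smelu or cpu::smelu depending on DLIB_USE_CUDA.
        tt::smelu(dest, src, 1.0f);
    }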
dlib/cuda/tensor_tools.h

@@ -1747,6 +1747,49 @@ namespace dlib { namespace tt
              is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------
    void smelu (
        tensor& dest,
        const tensor& src,
        const float beta
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
            - beta > 0
        ensures
            - for all valid i:
                - if (src.host()[i] > beta) then
                    - #dest.host()[i] == src.host()[i]
                - else if (src.host()[i] < -beta) then
                    - #dest.host()[i] == 0
                - else
                    - #dest.host()[i] == std::pow(src.host()[i] + beta, 2) / (4 * beta)
    !*/
    void smelu_gradient (
        tensor& grad,
        const tensor& dest,
        const tensor& gradient_input,
        const float beta
    );
    /*!
        requires
            - have_same_dimensions(dest,gradient_input) == true
            - have_same_dimensions(dest,grad) == true
            - beta > 0
        ensures
            - Recalling that dest is the output of smelu(dest,SRC) for some SRC tensor,
              let f(SRC) == dot(gradient_input,dest).  Then this function computes the
              gradient of f() with respect to SRC and stores it to grad.  Moreover, if
              is_same_object(grad,gradient_input)==true then the output is assigned to
              grad, replacing its previous contents.  Otherwise the output is added to
              grad.
            - This function supports in-place operation, i.e. having
              is_same_object(grad, gradient_input)==true
    !*/
    // ----------------------------------------------------------------------------------------

    void resize_bilinear (
    ...
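
A note on the gradient formula used by the kernels: the backward pass only sees the saved output y = (x + beta)^2 / (4 * beta), not x itself. For -beta < x < beta the derivative is (x + beta) / (2 * beta), and since sqrt(beta * y) = (x + beta) / 2 on that interval, the same value can be recovered from y as sqrt(beta * y) / beta, which is exactly what smelu_gradient computes (y == 0 maps to the clamped region with gradient 0, and y >= beta to the identity region with gradient 1). A small standalone check of that identity (illustrative, not part of the commit):

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float beta = 1.0f;
        const float x = 0.3f;                                  // a point in the quadratic region (|x| < beta)
        const float y = (x + beta) * (x + beta) / (4 * beta);  // SmeLU forward value
        const float grad_from_x = (x + beta) / (2 * beta);     // d/dx of the quadratic branch
        const float grad_from_y = std::sqrt(beta * y) / beta;  // what the backward kernels compute from y
        assert(std::abs(grad_from_x - grad_from_y) < 1e-6f);
    }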
dlib/dnn/layers.h

@@ -3113,6 +3113,7 @@ namespace dlib
    using prelu = add_layer<prelu_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    class leaky_relu_
    {
    public:
    ...

@@ -3629,6 +3630,85 @@ namespace dlib
    template <typename SUBNET>
    using gelu = add_layer<gelu_, SUBNET>;

    // ----------------------------------------------------------------------------------------
    class smelu_
    {
    public:
        explicit smelu_(
            float beta_ = 1
        ) : beta(beta_)
        {
        }

        float get_beta(
        ) const {
            return beta;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::smelu(output, input, beta);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::smelu_gradient(data_grad, computed_output, gradient_input, beta);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const smelu_& item, std::ostream& out)
        {
            serialize("smelu_", out);
            serialize(item.beta, out);
        }

        friend void deserialize(smelu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "smelu_")
                throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::smelu_.");
            deserialize(item.beta, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const smelu_& item)
        {
            out << "smelu\t (" << "beta=" << item.beta << ")";
            return out;
        }

        friend void to_xml(const smelu_& item, std::ostream& out)
        {
            out << "<smelu beta='" << item.beta << "'>\n";
            out << "<smelu/>\n";
        }

    private:
        resizable_tensor params;
        float beta;
    };

    template <typename SUBNET>
    using smelu = add_layer<smelu_, SUBNET>;
    // ----------------------------------------------------------------------------------------

    class softmax_
    ...
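
Because smelu_ follows the same in-place computational-layer pattern as leaky_relu_ and gelu_, it can also be constructed, printed, and serialized on its own. A small hedged sketch (the stream usage is illustrative; only the smelu_ class itself comes from this commit):

    #include <dlib/dnn.h>
    #include <iostream>
    #include <sstream>

    int main()
    {
        dlib::smelu_ act(0.5f);           // non-default beta
        std::cout << act << std::endl;    // prints something like: smelu (beta=0.5)

        // Round-trip through the friend serialize/deserialize shown above.
        std::stringstream ss;
        serialize(act, ss);
        dlib::smelu_ restored;
        deserialize(restored, ss);
        std::cout << restored.get_beta() << std::endl;  // 0.5
    }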
dlib/dnn/layers_abstract.h

@@ -2638,6 +2638,59 @@ namespace dlib
    template <typename SUBNET>
    using gelu = add_layer<gelu_, SUBNET>;

    // ----------------------------------------------------------------------------------------
    class smelu_
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
                defined above.  In particular, it defines a smooth rectified linear
                layer.  Therefore, it passes its inputs through the function f(x):
                    - if (x > beta) x
                    - if (x < -beta) 0
                    - else std::pow(x + beta, 2) / (4 * beta)
                where f() is applied pointwise across the input tensor and beta is a
                non-learned scalar.

                This is the layer type introduced in the paper:
                "Smooth activations and reproducibility in deep networks" by
                Gil I. Shamir, Dong Lin, Lorenzo Coviello (https://arxiv.org/abs/2010.09931)
        !*/
    public:

        explicit smelu_(
            float beta = 1
        );
        /*!
            ensures
                - the beta parameter will be initialized with the beta value
        !*/

        float get_beta(
        ) const;
        /*!
            ensures
                - returns the beta parameter of the smelu
        !*/

        template <typename SUBNET> void setup (const SUBNET& sub);
        void forward_inplace(const tensor& input, tensor& output);
        void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
        dpoint map_input_to_output(dpoint p) const;
        dpoint map_output_to_input(dpoint p) const;
        const tensor& get_layer_params() const;
        tensor& get_layer_params();
        /*!
            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
            interface.  Note that this layer doesn't have any parameters, so the tensor
            returned by get_layer_params() is always empty.
        !*/
    };
    template <typename SUBNET>
    using smelu = add_layer<smelu_, SUBNET>;
    // ----------------------------------------------------------------------------------------

    class softmax_
    ...
dlib/test/dnn.cpp

@@ -223,13 +223,13 @@ namespace
         // make sure that cuda::mish and cpu::mish return the same results
         using namespace dlib::tt;
         print_spinner();
-        const long n = 5;
+        const long n = 4;
         const long k = 5;
         const long nr = 3;
         const long nc = 3;
         resizable_tensor src(n, k, nr, nc);
         tt::tensor_rand rnd;
-        rnd.fill_uniform(src);
+        rnd.fill_gaussian(src);
         resizable_tensor dest1, dest2;
         dest1.copy_size(src);
    ...

@@ -239,7 +239,7 @@ namespace
         dest2 = 2;
         cuda::mish(dest1, src);
         cpu::mish(dest2, src);
-        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
+        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-6, max(abs(mat(dest1) - mat(dest2))));
 #endif // DLIB_USE_CUDA
     }
    ...

@@ -248,14 +248,14 @@ namespace
 #ifdef DLIB_USE_CUDA
         using namespace dlib::tt;
         print_spinner();
-        const long n = 5;
+        const long n = 4;
         const long k = 5;
         const long nr = 3;
         const long nc = 3;
         const float alpha = 0.01;
         resizable_tensor src(n, k, nr, nc);
         tt::tensor_rand rnd;
-        rnd.fill_uniform(src);
+        rnd.fill_gaussian(src);
         resizable_tensor dest_cuda, dest_cpu;
         dest_cuda.copy_size(src);
         dest_cpu.copy_size(src);
    ...

@@ -352,13 +352,13 @@ namespace
         // make sure that cuda::gelu and cpu::gelu return the same results
         using namespace dlib::tt;
         print_spinner();
-        const long n = 5;
+        const long n = 4;
         const long k = 5;
         const long nr = 3;
         const long nc = 3;
         resizable_tensor src(n, k, nr, nc);
         tt::tensor_rand rnd;
-        rnd.fill_uniform(src);
+        rnd.fill_gaussian(src);
         resizable_tensor dest1, dest2;
         dest1.copy_size(src);
    ...

@@ -368,7 +368,33 @@ namespace
         dest2 = 2;
         cuda::gelu(dest1, src);
         cpu::gelu(dest2, src);
-        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-7, max(abs(mat(dest1) - mat(dest2))));
+        DLIB_TEST_MSG(max(abs(mat(dest1) - mat(dest2))) < 1e-6, max(abs(mat(dest1) - mat(dest2))));
 #endif // DLIB_USE_CUDA
     }
    void test_smelu()
    {
#ifdef DLIB_USE_CUDA
        using namespace dlib::tt;
        print_spinner();
        const long n = 4;
        const long k = 5;
        const long nr = 3;
        const long nc = 3;
        const float beta = 1;
        resizable_tensor src(n, k, nr, nc);
        tt::tensor_rand rnd;
        rnd.fill_gaussian(src);
        resizable_tensor dest_cuda, dest_cpu;
        dest_cuda.copy_size(src);
        dest_cpu.copy_size(src);
        // initialize to different values in order to make sure the output is actually changed
        dest_cuda = 1;
        dest_cpu = 2;
        cuda::smelu(dest_cuda, src, beta);
        cpu::smelu(dest_cpu, src, beta);
        DLIB_TEST_MSG(max(abs(mat(dest_cuda) - mat(dest_cpu))) < 1e-7, max(abs(mat(dest_cuda) - mat(dest_cpu))));
#endif // DLIB_USE_CUDA
    }
    ...

@@ -2103,6 +2129,12 @@ namespace
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            smelu_ l;
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            softmax_ l;
    ...

@@ -4286,6 +4318,7 @@ namespace
        test_clipped_relu();
        test_elu();
        test_gelu();
        test_smelu();
        test_batch_normalize();
        test_batch_normalize_conv();
        test_layer_normalize();
    ...