Project: OpenDAS / dlib

Commit dcb5b46b authored Mar 31, 2016 by Davis King

Added prelu layer

Parent: ebf7a89a
Showing 8 changed files with 331 additions and 1 deletion (+331 -1)
dlib/dnn/cpu_dlib.cpp        +48  -0
dlib/dnn/cpu_dlib.h          +16  -0
dlib/dnn/cuda_dlib.cu        +63  -1
dlib/dnn/cuda_dlib.h         +16  -0
dlib/dnn/layers.h            +67  -0
dlib/dnn/layers_abstract.h   +49  -0
dlib/dnn/tensor_tools.cpp    +30  -0
dlib/dnn/tensor_tools.h      +42  -0
dlib/dnn/cpu_dlib.cpp
@@ -1160,6 +1160,54 @@ namespace dlib
        }
    }

    // ----------------------------------------------------------------------------------------

    void prelu (
        tensor& dest,
        const tensor& src,
        const tensor& param
    )
    {
        const float p = param.host()[0];
        const float* s = src.host();
        float* d = dest.host();
        for (size_t i = 0; i < dest.size(); ++i)
        {
            if (s[i] > 0)
                d[i] = s[i];
            else
                d[i] = p*s[i];
        }
    }

    void prelu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input,
        const tensor& param,
        tensor& params_grad
    )
    {
        const float p = param.host()[0];
        const float* gi = gradient_input.host();
        const float* s = src.host();
        float* out = grad.host();
        float pgrad = 0;
        for (size_t i = 0; i < src.size(); ++i)
        {
            if (s[i] > 0)
            {
                out[i] += gi[i];
            }
            else
            {
                out[i] += p*gi[i];
                pgrad += gi[i]*s[i];
            }
        }
        params_grad.host()[0] = pgrad;
    }

    // ------------------------------------------------------------------------------------

    void tanh (
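A note on the gradient loop above: since the layer computes f(x) = x>0 ? x : p*x, the backward pass applies df/dx = (x>0 ? 1 : p) to the incoming gradient and accumulates df/dp = (x>0 ? 0 : x) against it for the scalar parameter. A minimal sketch of those pointwise derivatives (illustrative only, not part of the commit; helper names are hypothetical):

    // Pointwise derivatives of f(x) = x>0 ? x : p*x.  The CPU loop above chains these
    // with the incoming gradient: out[i] += df_dx(s[i], p)*gi[i] and pgrad += df_dp(s[i])*gi[i].
    inline float df_dx(float x, float p) { return x > 0 ? 1.0f : p; }
    inline float df_dp(float x)          { return x > 0 ? 0.0f : x; }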
dlib/dnn/cpu_dlib.h
@@ -235,6 +235,22 @@ namespace dlib
        const tensor& gradient_input
    );

    // ----------------------------------------------------------------------------------------

    void prelu (
        tensor& dest,
        const tensor& src,
        const tensor& param
    );

    void prelu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input,
        const tensor& param,
        tensor& params_grad
    );

    // ------------------------------------------------------------------------------------

    void tanh (
dlib/dnn/cuda_dlib.cu
@@ -538,7 +538,69 @@ namespace dlib
        launch_kernel(_cuda_dot, max_jobs(a.size()), a.device(), b.device(), a.size(), result.device()+idx);
    }

    // ----------------------------------------------------------------------------------------
    // ----------------------------------------------------------------------------------------

    __global__ void _cuda_prelu(const float* s, float* d, size_t n, const float* pp)
    {
        const float p = *pp;
        for (auto i : grid_stride_range(0, n))
        {
            if (s[i] > 0)
                d[i] = s[i];
            else
                d[i] = p*s[i];
        }
    }

    void prelu (
        tensor& dest,
        const tensor& src,
        const tensor& param
    )
    {
        launch_kernel(_cuda_prelu, max_jobs(dest.size()),
                      src.device(), dest.device(), src.size(), param.device());
    }

    // ----------------------------------------------------------------------------------------

    __global__ void _cuda_prelu_gradient(float* out, const float* s, const float* gi, size_t n, const float* pp, float* ppgrad)
    {
        const float p = *pp;
        float pgrad = 0;
        for (auto i : grid_stride_range(0, n))
        {
            if (s[i] > 0)
            {
                out[i] += gi[i];
            }
            else
            {
                out[i] += p*gi[i];
                pgrad += gi[i]*s[i];
            }
        }

        // Then do the warp reduce add thing to merge into one output value.
        warp_reduce_atomic_add(*ppgrad, pgrad);
    }

    void prelu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input,
        const tensor& param,
        tensor& params_grad
    )
    {
        params_grad = 0;
        launch_kernel(_cuda_prelu_gradient, max_jobs(grad.size()),
                      grad.device(), src.device(), gradient_input.device(),
                      grad.size(), param.device(), params_grad.device());
    }

    // ----------------------------------------------------------------------------------------

    }
}
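One behavioral difference from the CPU path: because many GPU threads each hold a partial sum of the parameter gradient, prelu_gradient here first zeroes params_grad and then lets _cuda_prelu_gradient accumulate into it through warp_reduce_atomic_add, whereas the CPU version computes the sum in a single loop and assigns it directly. A serial sketch of the value both backends reduce to (illustrative only, not dlib code):

    #include <cstddef>

    // Serial form of the parameter-gradient reduction performed by _cuda_prelu_gradient;
    // the kernel reaches the same total via per-thread partial sums merged by an atomic add.
    float reduce_param_gradient(const float* s, const float* gi, std::size_t n)
    {
        float pgrad = 0;
        for (std::size_t i = 0; i < n; ++i)
            if (!(s[i] > 0))
                pgrad += gi[i]*s[i];
        return pgrad;
    }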
dlib/dnn/cuda_dlib.h
@@ -135,6 +135,22 @@ namespace dlib
        size_t idx
    );

    // ----------------------------------------------------------------------------------------

    void prelu (
        tensor& dest,
        const tensor& src,
        const tensor& param
    );

    void prelu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input,
        const tensor& param,
        tensor& params_grad
    );

    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
dlib/dnn/layers.h
@@ -1119,6 +1119,73 @@ namespace dlib
    template <typename SUBNET>
    using relu = add_layer<relu_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    class prelu_
    {
    public:
        explicit prelu_(
            float initial_param_value_ = 0.25
        ) : initial_param_value(initial_param_value_)
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
            params.set_size(1);
            params = initial_param_value;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& data_output)
        {
            data_output.copy_size(sub.get_output());
            tt::prelu(data_output, sub.get_output(), params);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(),
                               gradient_input, params, params_grad);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const prelu_& item, std::ostream& out)
        {
            serialize("prelu_", out);
            serialize(item.params, out);
            serialize(item.initial_param_value, out);
        }

        friend void deserialize(prelu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "prelu_")
                throw serialization_error("Unexpected version found while deserializing dlib::prelu_.");
            deserialize(item.params, in);
            deserialize(item.initial_param_value, in);
        }

    private:
        resizable_tensor params;
        float initial_param_value;
    };

    template <typename SUBNET>
    using prelu = add_layer<prelu_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    class sig_
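With the prelu alias defined above, the layer drops into a network definition the same way relu does. A hedged sketch (assuming the input, fc, matrix, and loss_multiclass_log components of dlib's DNN module; the exact fc template parameters at the time of this commit may differ):

    #include <dlib/dnn.h>
    using namespace dlib;

    // Hypothetical small network that uses prelu where relu would normally go.
    using net_type = loss_multiclass_log<
                         fc<10,
                         prelu<fc<84,
                         input<matrix<unsigned char>>
                         >>>>;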
dlib/dnn/layers_abstract.h
@@ -1066,6 +1066,55 @@ namespace dlib
    template <typename SUBNET>
    using relu = add_layer<relu_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    class prelu_
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_LAYER_ interface defined above.
                In particular, it defines a parametric rectified linear layer.  Therefore,
                it passes its inputs through the function
                    f(x) = x>0 ? x : p*x
                where f() is applied pointwise across the input tensor and p is a scalar
                parameter learned by this layer.

                This is the layer type introduced in the paper:
                    He, Kaiming, et al. "Delving deep into rectifiers: Surpassing
                    human-level performance on imagenet classification." Proceedings of
                    the IEEE International Conference on Computer Vision. 2015.
        !*/

    public:
        explicit prelu_(
            float initial_param_value = 0.25
        );
        /*!
            ensures
                - The p parameter will be initialized with initial_param_value.
        !*/

        template <typename SUBNET> void setup (const SUBNET& sub);
        void forward_inplace(const tensor& input, tensor& output);
        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor& params_grad
        );
        const tensor& get_layer_params() const;
        tensor& get_layer_params();
        /*!
            These functions are implemented as described in the EXAMPLE_LAYER_ interface.
        !*/
    };

    void serialize(const prelu_& item, std::ostream& out);
    void deserialize(prelu_& item, std::istream& in);
    /*!
        provides serialization support
    !*/

    template <typename SUBNET>
    using prelu = add_layer<prelu_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    class sig_
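For concreteness, with the default initial_param_value of 0.25 the function documented above maps 3.0 to 3.0 and -2.0 to -0.5. A one-line sketch of the pointwise map (illustrative only; the helper name is hypothetical):

    // f() from the documentation above; p is the layer's single learned parameter (0.25 by default).
    inline float prelu_pointwise(float x, float p) { return x > 0 ? x : p*x; }
    // prelu_pointwise( 3.0f, 0.25f) ==  3.0f
    // prelu_pointwise(-2.0f, 0.25f) == -0.5f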
dlib/dnn/tensor_tools.cpp
@@ -515,6 +515,36 @@ namespace dlib { namespace tt
#endif
    }

    // ----------------------------------------------------------------------------------------

    void prelu (
        tensor& dest,
        const tensor& src,
        const tensor& param
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::prelu(dest, src, param);
#else
        cpu::prelu(dest, src, param);
#endif
    }

    void prelu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input,
        const tensor& param,
        tensor& params_grad
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::prelu_gradient(grad, src, gradient_input, param, params_grad);
#else
        cpu::prelu_gradient(grad, src, gradient_input, param, params_grad);
#endif
    }

    // ----------------------------------------------------------------------------------------

    void tanh (
dlib/dnn/tensor_tools.h
@@ -896,6 +896,48 @@ namespace dlib { namespace tt
            is_same_object(grad, gradient_input)==true
    !*/

    // ----------------------------------------------------------------------------------------

    void prelu (
        tensor& dest,
        const tensor& src,
        const tensor& param
    );
    /*!
        requires
            - have_same_dimensions(dest, src) == true
            - param.size() == 1
        ensures
            - for all valid i:
                - if (src.host()[i] > 0) then
                    - #dest.host()[i] == src.host()[i]
                - else
                    - #dest.host()[i] == src.host()[i] * param.host()[0]
            - This function supports in-place operation, i.e. having
              is_same_object(dest, src)==true
    !*/

    void prelu_gradient (
        tensor& grad,
        const tensor& src,
        const tensor& gradient_input,
        const tensor& param,
        tensor& params_grad
    );
    /*!
        requires
            - have_same_dimensions(grad,src) == true
            - have_same_dimensions(grad,gradient_input) == true
            - param.size() == 1
            - params_grad.size() == 1
        ensures
            - Recalling that dest is the output of prelu(dest,src,param) let
              f(src,param) == dot(gradient_input,dest)
            - Then this function computes the gradient of f() with respect to src and
              param.  It assigns the gradient with respect to param to #params_grad and
              adds the gradient with respect to src to #grad.
    !*/

    // ----------------------------------------------------------------------------------------

    void tanh (
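A minimal sketch exercising tt::prelu against the ensures clause above (assumes the dlib/dnn.h header and the tensor accessors used elsewhere in this commit; the input values are arbitrary):

    #include <dlib/dnn.h>
    #include <cassert>
    using namespace dlib;

    int main()
    {
        resizable_tensor src, dest, param;
        src.set_size(1,1,1,3);
        src.host()[0] = 2;  src.host()[1] = -2;  src.host()[2] = 0.5;
        dest.copy_size(src);
        param.set_size(1);
        param = 0.25f;      // the scalar slope p

        tt::prelu(dest, src, param);

        // Per the spec: positive inputs pass through, non-positive inputs are scaled by p.
        assert(dest.host()[0] ==  2.0f);
        assert(dest.host()[1] == -0.5f);
        assert(dest.host()[2] ==  0.5f);
        return 0;
    }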