Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
6c36592c
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b7b1a30bc49cad350c7a642e1171e886d83cd909"
Commit
6c36592c
authored
Oct 15, 2015
by
Davis King
Browse files
Added serialization support to everything.
parent
e679d66a
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
437 additions
and
69 deletions
+437
-69
dlib/dnn/core.h
dlib/dnn/core.h
+147
-1
dlib/dnn/core_abstract.h
dlib/dnn/core_abstract.h
+38
-0
dlib/dnn/input.h
dlib/dnn/input.h
+28
-0
dlib/dnn/input_abstract.h
dlib/dnn/input_abstract.h
+14
-0
dlib/dnn/layers.h
dlib/dnn/layers.h
+32
-67
dlib/dnn/layers_abstract.h
dlib/dnn/layers_abstract.h
+16
-0
dlib/dnn/loss.h
dlib/dnn/loss.h
+26
-0
dlib/dnn/loss_abstract.h
dlib/dnn/loss_abstract.h
+12
-0
dlib/dnn/solvers.h
dlib/dnn/solvers.h
+21
-0
dlib/dnn/solvers_abstract.h
dlib/dnn/solvers_abstract.h
+12
-0
dlib/dnn/tensor.h
dlib/dnn/tensor.h
+56
-0
dlib/dnn/trainer.h
dlib/dnn/trainer.h
+27
-1
dlib/dnn/trainer_abstract.h
dlib/dnn/trainer_abstract.h
+8
-0
No files found.
dlib/dnn/core.h
View file @
6c36592c
...
...
@@ -67,6 +67,18 @@ namespace dlib
const
sstack
<
T
,
N
-
1
>&
pop
()
const
{
return
data
;
}
sstack
<
T
,
N
-
1
>&
pop
()
{
return
data
;
}
friend
void
serialize
(
const
sstack
&
item
,
std
::
ostream
&
out
)
{
serialize
(
item
.
top
(),
out
);
serialize
(
item
.
pop
(),
out
);
}
friend
void
deserialize
(
sstack
&
item
,
std
::
istream
&
in
)
{
deserialize
(
item
.
top
(),
in
);
deserialize
(
item
.
pop
(),
in
);
}
private:
T
item
;
sstack
<
T
,
N
-
1
>
data
;
...
...
@@ -83,6 +95,17 @@ namespace dlib
T
&
top
()
{
return
item
;
}
size_t
size
()
const
{
return
1
;
}
friend
void
serialize
(
const
sstack
&
item
,
std
::
ostream
&
out
)
{
serialize
(
item
.
top
(),
out
);
}
friend
void
deserialize
(
sstack
&
item
,
std
::
istream
&
in
)
{
deserialize
(
item
.
top
(),
in
);
}
private:
T
item
;
};
...
...
@@ -294,6 +317,32 @@ namespace dlib
subnetwork
.
clean
();
}
friend
void
serialize
(
const
add_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
subnetwork
,
out
);
serialize
(
item
.
details
,
out
);
serialize
(
item
.
this_layer_setup_called
,
out
);
serialize
(
item
.
gradient_input_is_stale
,
out
);
serialize
(
item
.
x_grad
,
out
);
serialize
(
item
.
cached_output
,
out
);
}
friend
void
deserialize
(
add_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_layer."
);
deserialize
(
item
.
subnetwork
,
in
);
deserialize
(
item
.
details
,
in
);
deserialize
(
item
.
this_layer_setup_called
,
in
);
deserialize
(
item
.
gradient_input_is_stale
,
in
);
deserialize
(
item
.
x_grad
,
in
);
deserialize
(
item
.
cached_output
,
in
);
}
private:
...
...
@@ -468,6 +517,32 @@ namespace dlib
gradient_input_is_stale
=
true
;
}
friend
void
serialize
(
const
add_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
input_layer
,
out
);
serialize
(
item
.
details
,
out
);
serialize
(
item
.
this_layer_setup_called
,
out
);
serialize
(
item
.
gradient_input_is_stale
,
out
);
serialize
(
item
.
x_grad
,
out
);
serialize
(
item
.
cached_output
,
out
);
}
friend
void
deserialize
(
add_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_layer."
);
deserialize
(
item
.
input_layer
,
in
);
deserialize
(
item
.
details
,
in
);
deserialize
(
item
.
this_layer_setup_called
,
in
);
deserialize
(
item
.
gradient_input_is_stale
,
in
);
deserialize
(
item
.
x_grad
,
in
);
deserialize
(
item
.
cached_output
,
in
);
}
private:
class
subnet_wrapper
...
...
@@ -601,6 +676,22 @@ namespace dlib
subnetwork
.
clean
();
}
friend
void
serialize
(
const
add_tag_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
subnetwork
,
out
);
}
friend
void
deserialize
(
add_tag_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_tag_layer."
);
deserialize
(
item
.
subnetwork
,
in
);
}
private:
subnet_type
subnetwork
;
...
...
@@ -702,6 +793,26 @@ namespace dlib
cached_output
.
clear
();
}
friend
void
serialize
(
const
add_tag_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
input_layer
,
out
);
serialize
(
item
.
cached_output
,
out
);
serialize
(
item
.
grad_final_ignored
,
out
);
}
friend
void
deserialize
(
add_tag_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_tag_layer."
);
deserialize
(
item
.
input_layer
,
in
);
deserialize
(
item
.
cached_output
,
in
);
deserialize
(
item
.
grad_final_ignored
,
in
);
}
private:
subnet_type
input_layer
;
...
...
@@ -759,7 +870,8 @@ namespace dlib
const
static
unsigned
int
sample_expansion_factor
=
subnet_type
::
sample_expansion_factor
;
typedef
typename
get_loss_layer_label_type
<
LOSS_DETAILS
>::
type
label_type
;
static_assert
(
is_nonloss_layer_type
<
SUBNET
>::
value
,
"SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer."
);
static_assert
(
is_nonloss_layer_type
<
SUBNET
>::
value
,
"SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer."
);
static_assert
(
sample_expansion_factor
==
LOSS_DETAILS
::
sample_expansion_factor
,
"The loss layer and input layer must agree on the sample_expansion_factor."
);
...
...
@@ -947,6 +1059,24 @@ namespace dlib
subnetwork
.
clear
();
}
friend
void
serialize
(
const
add_loss_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
loss
,
out
);
serialize
(
item
.
subnetwork
,
out
);
}
friend
void
deserialize
(
add_loss_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_loss_layer."
);
deserialize
(
item
.
loss
,
in
);
deserialize
(
item
.
subnetwork
,
in
);
}
private:
loss_details_type
loss
;
...
...
@@ -1150,6 +1280,22 @@ namespace dlib
subnetwork
.
clean
();
}
friend
void
serialize
(
const
add_skip_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
subnetwork
,
out
);
}
friend
void
deserialize
(
add_skip_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_skip_layer."
);
deserialize
(
item
.
subnetwork
,
in
);
}
private:
subnet_type
subnetwork
;
...
...
dlib/dnn/core_abstract.h
View file @
6c36592c
...
...
@@ -119,6 +119,12 @@ namespace dlib
!*/
};
void
serialize
(
const
sstack
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
sstack
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
template
<
...
...
@@ -378,6 +384,14 @@ namespace dlib
};
template
<
typename
T
,
typename
U
>,
void
serialize
(
const
add_layer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
,
typename
U
>,
void
deserialize
(
add_layer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
@@ -769,6 +783,14 @@ namespace dlib
!*/
};
template
<
typename
T
,
typename
U
>,
void
serialize
(
const
add_loss_layer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
,
typename
U
>,
void
deserialize
(
add_loss_layer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
@@ -799,6 +821,14 @@ namespace dlib
!*/
};
template
<
unsigned
long
ID
,
typename
U
>,
void
serialize
(
const
add_tag_layer
<
ID
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
unsigned
long
ID
,
typename
U
>,
void
deserialize
(
add_tag_layer
<
ID
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
tag1
=
add_tag_layer
<
1
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
tag2
=
add_tag_layer
<
2
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
tag3
=
add_tag_layer
<
3
,
SUBNET
>
;
...
...
@@ -834,6 +864,14 @@ namespace dlib
!*/
};
template
<
template
<
typename
>
class
T
,
typename
U
>
void
serialize
(
const
add_skip_layer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
template
<
typename
>
class
T
,
typename
U
>
void
deserialize
(
add_skip_layer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
skip1
=
add_skip_layer
<
tag1
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
skip2
=
add_skip_layer
<
tag2
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
skip3
=
add_skip_layer
<
tag3
,
SUBNET
>
;
...
...
dlib/dnn/input.h
View file @
6c36592c
...
...
@@ -73,6 +73,20 @@ namespace dlib
}
}
friend
void
serialize
(
const
input
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"input<matrix>"
,
out
);
}
friend
void
deserialize
(
input
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"input<matrix>"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::input."
);
}
};
// ----------------------------------------------------------------------------------------
...
...
@@ -126,6 +140,20 @@ namespace dlib
}
}
friend
void
serialize
(
const
input
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"input<array2d>"
,
out
);
}
friend
void
deserialize
(
input
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"input<array2d>"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::input."
);
}
};
// ----------------------------------------------------------------------------------------
...
...
dlib/dnn/input_abstract.h
View file @
6c36592c
...
...
@@ -86,6 +86,12 @@ namespace dlib
!*/
};
void
serialize
(
const
EXAMPLE_INPUT_LAYER
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_INPUT_LAYER
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
template
<
...
...
@@ -132,6 +138,14 @@ namespace dlib
!*/
};
template
<
typename
T
>
void
serialize
(
const
input
<
T
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
>
void
deserialize
(
input
<
T
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
...
...
dlib/dnn/layers.h
View file @
6c36592c
...
...
@@ -59,13 +59,12 @@ namespace dlib
public:
fc_
()
:
num_outputs
(
1
)
{
rnd
.
set_seed
(
"fc_"
+
cast_to_string
(
num_outputs
));
}
explicit
fc_
(
unsigned
long
num_outputs_
)
explicit
fc_
(
unsigned
long
num_outputs_
)
:
num_outputs
(
num_outputs_
)
{
num_outputs
=
num_outputs_
;
rnd
.
set_seed
(
"fc_"
+
cast_to_string
(
num_outputs
));
}
unsigned
long
get_num_outputs
(
...
...
@@ -77,6 +76,7 @@ namespace dlib
num_inputs
=
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
();
params
.
set_size
(
num_inputs
,
num_outputs
);
dlib
::
rand
rnd
(
"fc_"
+
cast_to_string
(
num_outputs
));
randomize_parameters
(
params
,
num_inputs
+
num_outputs
,
rnd
);
}
...
...
@@ -101,12 +101,30 @@ namespace dlib
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
friend
void
serialize
(
const
fc_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"fc_"
,
out
);
serialize
(
item
.
num_outputs
,
out
);
serialize
(
item
.
num_inputs
,
out
);
serialize
(
item
.
params
,
out
);
}
friend
void
deserialize
(
fc_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"fc_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::fc_."
);
deserialize
(
item
.
num_outputs
,
in
);
deserialize
(
item
.
num_inputs
,
in
);
deserialize
(
item
.
params
,
in
);
}
private:
unsigned
long
num_outputs
;
unsigned
long
num_inputs
;
resizable_tensor
params
;
dlib
::
rand
rnd
;
};
...
...
@@ -151,81 +169,28 @@ namespace dlib
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
private:
resizable_tensor
params
;
};
template
<
typename
SUBNET
>
using
relu
=
add_layer
<
relu_
,
SUBNET
>
;
// ----------------------------------------------------------------------------------------
class
multiply_
{
public:
multiply_
()
{
}
template
<
typename
SUBNET
>
void
setup
(
const
SUBNET
&
sub
)
friend
void
serialize
(
const
relu_
&
item
,
std
::
ostream
&
out
)
{
num_inputs
=
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
();
params
.
set_size
(
1
,
num_inputs
);
std
::
cout
<<
"multiply_::setup() "
<<
params
.
size
()
<<
std
::
endl
;
const
int
num_outputs
=
num_inputs
;
randomize_parameters
(
params
,
num_inputs
+
num_outputs
,
rnd
);
serialize
(
"relu_"
,
out
);
}
template
<
typename
SUBNET
>
void
forward
(
const
SUBNET
&
sub
,
resizable_tensor
&
output
)
{
DLIB_CASSERT
(
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
()
==
params
.
size
(),
""
);
DLIB_CASSERT
(
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
()
==
num_inputs
,
""
);
output
.
copy_size
(
sub
.
get_output
());
auto
indata
=
sub
.
get_output
().
host
();
auto
outdata
=
output
.
host
();
auto
paramdata
=
params
.
host
();
for
(
int
i
=
0
;
i
<
sub
.
get_output
().
num_samples
();
++
i
)
{
for
(
int
j
=
0
;
j
<
num_inputs
;
++
j
)
{
*
outdata
++
=
*
indata
++
*
paramdata
[
j
];
}
}
}
template
<
typename
SUBNET
>
void
backward
(
const
tensor
&
gradient_input
,
SUBNET
&
sub
,
tensor
&
params_grad
)
friend
void
deserialize
(
relu_
&
item
,
std
::
istream
&
in
)
{
params_grad
+=
sum_rows
(
pointwise_multiply
(
mat
(
sub
.
get_output
()),
mat
(
gradient_input
)));
for
(
long
i
=
0
;
i
<
gradient_input
.
num_samples
();
++
i
)
{
sub
.
get_gradient_input
().
add_to_sample
(
i
,
pointwise_multiply
(
rowm
(
mat
(
gradient_input
),
i
),
mat
(
params
)));
}
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"relu_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::relu_."
);
}
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
private:
int
num_inputs
;
resizable_tensor
params
;
dlib
::
rand
rnd
;
};
template
<
typename
SUBNET
>
using
multiply
=
add_layer
<
multiply
_
,
SUBNET
>
;
using
relu
=
add_layer
<
relu
_
,
SUBNET
>
;
// ----------------------------------------------------------------------------------------
...
...
dlib/dnn/layers_abstract.h
View file @
6c36592c
...
...
@@ -218,6 +218,12 @@ namespace dlib
};
void
serialize
(
const
EXAMPLE_LAYER_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_LAYER_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// For each layer you define, always define an add_layer template so that layers can be
// easily composed. Moreover, the convention is that the layer class ends with an _
// while the add_layer template has the same name but without the trailing _.
...
...
@@ -274,6 +280,11 @@ namespace dlib
!*/
};
void
serialize
(
const
fc_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
fc_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
fc
=
add_layer
<
fc_
,
SUBNET
>
;
...
...
@@ -306,6 +317,11 @@ namespace dlib
!*/
};
void
serialize
(
const
relu_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
relu_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
relu
=
add_layer
<
relu_
,
SUBNET
>
;
...
...
dlib/dnn/loss.h
View file @
6c36592c
...
...
@@ -81,6 +81,19 @@ namespace dlib
return
loss
;
}
friend
void
serialize
(
const
loss_binary_hinge_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"loss_binary_hinge_"
,
out
);
}
friend
void
deserialize
(
loss_binary_hinge_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"loss_binary_hinge_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::loss_binary_hinge_."
);
}
};
template
<
typename
SUBNET
>
...
...
@@ -105,6 +118,19 @@ namespace dlib
return
0
;
}
friend
void
serialize
(
const
loss_no_label_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"loss_no_label_"
,
out
);
}
friend
void
deserialize
(
loss_no_label_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"loss_no_label_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::loss_no_label_."
);
}
};
template
<
typename
SUBNET
>
...
...
dlib/dnn/loss_abstract.h
View file @
6c36592c
...
...
@@ -118,6 +118,12 @@ namespace dlib
!*/
};
void
serialize
(
const
EXAMPLE_LOSS_LAYER_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_LOSS_LAYER_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// For each loss layer you define, always define an add_loss_layer template so that
// layers can be easily composed. Moreover, the convention is that the layer class
// ends with an _ while the add_loss_layer template has the same name but without the
...
...
@@ -179,6 +185,12 @@ namespace dlib
};
void
serialize
(
const
loss_binary_hinge_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
loss_binary_hinge_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
loss_binary_hinge
=
add_loss_layer
<
loss_binary_hinge_
,
SUBNET
>
;
...
...
dlib/dnn/solvers.h
View file @
6c36592c
...
...
@@ -48,6 +48,27 @@ namespace dlib
l
.
get_layer_params
()
+=
v
;
}
friend
void
serialize
(
const
sgd
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"sgd"
,
out
);
serialize
(
item
.
v
,
out
);
serialize
(
item
.
weight_decay
,
out
);
serialize
(
item
.
learning_rate
,
out
);
serialize
(
item
.
momentum
,
out
);
}
friend
void
deserialize
(
sgd
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"sgd"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::sgd."
);
deserialize
(
item
.
v
,
in
);
deserialize
(
item
.
weight_decay
,
in
);
deserialize
(
item
.
learning_rate
,
in
);
deserialize
(
item
.
momentum
,
in
);
}
private:
matrix
<
float
>
v
;
float
weight_decay
;
...
...
dlib/dnn/solvers_abstract.h
View file @
6c36592c
...
...
@@ -52,6 +52,12 @@ namespace dlib
!*/
};
void
serialize
(
const
EXAMPLE_SOLVER
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_SOLVER
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
@@ -92,6 +98,12 @@ namespace dlib
float
get_momentum
()
const
;
};
void
serialize
(
const
sgd
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
sgd
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
...
...
dlib/dnn/tensor.h
View file @
6c36592c
...
...
@@ -112,6 +112,7 @@ namespace dlib
size_t
size
()
const
{
return
data_size
;
}
private:
void
copy_to_device
()
const
...
...
@@ -144,6 +145,30 @@ namespace dlib
std
::
unique_ptr
<
float
[]
>
data_device
;
};
inline
void
serialize
(
const
gpu_data
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
item
.
size
(),
out
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
serialize
(
data
[
i
],
out
);
}
inline
void
deserialize
(
gpu_data
&
item
,
std
::
istream
&
in
)
{
int
version
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::gpu_data."
);
size_t
s
;
deserialize
(
s
,
in
);
item
.
set_size
(
s
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
deserialize
(
data
[
i
],
in
);
}
// ----------------------------------------------------------------------------------------
class
tensor
...
...
@@ -466,6 +491,37 @@ namespace dlib
}
};
inline
void
serialize
(
const
tensor
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
num_samples
(),
out
);
serialize
(
item
.
nr
(),
out
);
serialize
(
item
.
nc
(),
out
);
serialize
(
item
.
k
(),
out
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
serialize
(
data
[
i
],
out
);
}
inline
void
deserialize
(
resizable_tensor
&
item
,
std
::
istream
&
in
)
{
int
version
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::resizable_tensor."
);
long
num_samples
=
0
,
nr
=
0
,
nc
=
0
,
k
=
0
;
deserialize
(
num_samples
,
in
);
deserialize
(
nr
,
in
);
deserialize
(
nc
,
in
);
deserialize
(
k
,
in
);
item
.
set_size
(
num_samples
,
nr
,
nc
,
k
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
deserialize
(
data
[
i
],
in
);
}
// ----------------------------------------------------------------------------------------
inline
double
dot
(
...
...
dlib/dnn/trainer.h
View file @
6c36592c
...
...
@@ -9,6 +9,7 @@
#include "../statistics.h"
#include "../console_progress_indicator.h"
#include <chrono>
#include "../serialize.h"
namespace
dlib
{
...
...
@@ -281,8 +282,34 @@ namespace dlib
return
net
;
}
friend
void
serialize
(
const
dnn_trainer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
num_epochs
,
out
);
serialize
(
item
.
mini_batch_size
,
out
);
serialize
(
item
.
verbose
,
out
);
serialize
(
item
.
net
,
out
);
serialize
(
item
.
solvers
,
out
);
}
friend
void
deserialize
(
dnn_trainer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::dnn_trainer."
);
deserialize
(
item
.
num_epochs
,
in
);
deserialize
(
item
.
mini_batch_size
,
in
);
deserialize
(
item
.
verbose
,
in
);
deserialize
(
item
.
net
,
in
);
deserialize
(
item
.
solvers
,
in
);
}
private:
const
static
long
string_pad
=
10
;
void
init
()
{
num_epochs
=
300
;
...
...
@@ -293,7 +320,6 @@ namespace dlib
unsigned
long
num_epochs
;
unsigned
long
mini_batch_size
;
bool
verbose
;
const
static
long
string_pad
=
10
;
net_type
net
;
sstack
<
solver_type
,
net_type
::
num_layers
>
solvers
;
...
...
dlib/dnn/trainer_abstract.h
View file @
6c36592c
...
...
@@ -222,6 +222,14 @@ namespace dlib
};
template
<
typename
T
,
typename
U
>
void
serialize
(
const
dnn_trainer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
,
typename
U
>
void
deserialize
(
dnn_trainer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment