Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
47bdf95f
"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "bfd7cee1940aee1c217b3e3a55c1de10eed1732c"
Commit
47bdf95f
authored
Mar 27, 2016
by
Davis King
Browse files
added more stuff to example
parent
bd79b877
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
5 deletions
+22
-5
examples/dnn_mnist_resnet_ex.cpp
examples/dnn_mnist_resnet_ex.cpp
+22
-5
No files found.
examples/dnn_mnist_resnet_ex.cpp
View file @
47bdf95f
...
@@ -60,8 +60,7 @@ int main(int argc, char** argv) try
...
@@ -60,8 +60,7 @@ int main(int argc, char** argv) try
);
);
//dnn_trainer<net_type,adam> trainer(net,adam(0.001));
dnn_trainer
<
net_type
,
adam
>
trainer
(
net
,
adam
(
0.001
));
dnn_trainer
<
net_type
>
trainer
(
net
,
sgd
(
0.1
));
trainer
.
be_verbose
();
trainer
.
be_verbose
();
trainer
.
set_synchronization_file
(
"mnist_resnet_sync"
,
std
::
chrono
::
seconds
(
100
));
trainer
.
set_synchronization_file
(
"mnist_resnet_sync"
,
std
::
chrono
::
seconds
(
100
));
std
::
vector
<
matrix
<
unsigned
char
>>
mini_batch_samples
;
std
::
vector
<
matrix
<
unsigned
char
>>
mini_batch_samples
;
...
@@ -86,11 +85,29 @@ int main(int argc, char** argv) try
...
@@ -86,11 +85,29 @@ int main(int argc, char** argv) try
// wait for threaded processing to stop.
// wait for threaded processing to stop.
trainer
.
get_net
();
trainer
.
get_net
();
// You can access sub layers of the network like this:
net
.
subnet
().
subnet
().
get_output
();
layer
<
avg_pool
>
(
net
).
get_output
();
net
.
clean
();
net
.
clean
();
serialize
(
"mnist_network.dat"
)
<<
net
;
serialize
(
"mnist_res_network.dat"
)
<<
net
;
typedef
loss_multiclass_log
<
fc
<
avg_pool
<
ares
<
ares
<
ares
<
ares
<
repeat
<
10
,
ares
,
ares
<
ares
<
input
<
matrix
<
unsigned
char
>
>>>>>>>>>>>
test_net_type
;
test_net_type
tnet
=
net
;
// or you could deserialize the saved network
deserialize
(
"mnist_res_network.dat"
)
>>
tnet
;
// Run the net on all the data to get predictions
// Run the net on all the data to get predictions
std
::
vector
<
unsigned
long
>
predicted_labels
=
net
(
training_images
);
std
::
vector
<
unsigned
long
>
predicted_labels
=
t
net
(
training_images
);
int
num_right
=
0
;
int
num_right
=
0
;
int
num_wrong
=
0
;
int
num_wrong
=
0
;
for
(
size_t
i
=
0
;
i
<
training_images
.
size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
training_images
.
size
();
++
i
)
...
@@ -105,7 +122,7 @@ int main(int argc, char** argv) try
...
@@ -105,7 +122,7 @@ int main(int argc, char** argv) try
cout
<<
"training num_wrong: "
<<
num_wrong
<<
endl
;
cout
<<
"training num_wrong: "
<<
num_wrong
<<
endl
;
cout
<<
"training accuracy: "
<<
num_right
/
(
double
)(
num_right
+
num_wrong
)
<<
endl
;
cout
<<
"training accuracy: "
<<
num_right
/
(
double
)(
num_right
+
num_wrong
)
<<
endl
;
predicted_labels
=
net
(
testing_images
);
predicted_labels
=
t
net
(
testing_images
);
num_right
=
0
;
num_right
=
0
;
num_wrong
=
0
;
num_wrong
=
0
;
for
(
size_t
i
=
0
;
i
<
testing_images
.
size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
testing_images
.
size
();
++
i
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment