Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
f4b3c7ee
"tests/vscode:/vscode.git/clone" did not exist on "729375741013ad93daf6afb7c73f8b874e0168cb"
Commit
f4b3c7ee
authored
Dec 17, 2016
by
Davis King
Browse files
Improved example
parent
f28d2f73
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
7 deletions
+11
-7
examples/dnn_metric_learning_ex.cpp
examples/dnn_metric_learning_ex.cpp
+11
-7
No files found.
examples/dnn_metric_learning_ex.cpp
View file @
f4b3c7ee
...
...
@@ -14,7 +14,6 @@
space it's very easy to do face recognition with some kind of k-nearest
neighbor classifier.
To keep this example as simple as possible we won't do face recognition.
Instead, we will create a very simple network and use it to learn a mapping
from 8D vectors to 2D vectors such that vectors with the same class labels
...
...
@@ -65,15 +64,20 @@ int main() try
// vectors.
using
net_type
=
loss_metric
<
fc
<
2
,
input
<
matrix
<
double
,
0
,
1
>>>>
;
net_type
net
;
// Now setup the trainer and train the network using our data.
dnn_trainer
<
net_type
>
trainer
(
net
);
trainer
.
set_learning_rate
(
0.1
);
trainer
.
set_min_learning_rate
(
0.001
);
trainer
.
set_mini_batch_size
(
128
);
trainer
.
be_verbose
();
trainer
.
set_iterations_without_progress_threshold
(
100
);
trainer
.
train
(
samples
,
labels
);
// It should be emphasized out that it's really important that each mini-batch contain
// multiple instances of each class of object. This is because the metric learning
// algorithm needs to consider pairs of objects that should be close as well as pairs
// of objects that should be far apart during each training step. Here we just keep
// training on the same small batch so this constraint is trivially satisfied.
while
(
trainer
.
get_learning_rate
()
>=
1e-4
)
trainer
.
train_one_step
(
samples
,
labels
);
// Wait for training threads to stop
trainer
.
get_net
();
cout
<<
"done training"
<<
endl
;
// Run all the samples through the network to get their 2D vector embeddings.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment