dnn_metric_learning_ex.cpp
Davis King committed
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This is an example illustrating the use of the deep learning tools from the
    dlib C++ Library.  In it, we will show how to use the loss_metric layer to do
    metric learning.  

    The main reason you might want to use this kind of algorithm is because you
    would like to use a k-nearest neighbor classifier or similar algorithm, but
    you don't know a good way to calculate the distance between two things.  A
    popular example would be face recognition.  There are a whole lot of papers
    that train some kind of deep metric learning algorithm that embeds face
    images in some vector space where images of the same person are close to each
    other and images of different people are far apart.  Then in that vector
    space it's very easy to do face recognition with some kind of k-nearest
    neighbor classifier.  
    

    To keep this example as simple as possible we won't do face recognition.
    Instead, we will create a very simple network and use it to learn a mapping
    from 8D vectors to 2D vectors such that vectors with the same class labels
    are near each other.  If you want to see a more complex example that learns
    the kind of network you would use for something like face recognition, read
    the dnn_metric_learning_on_images_ex.cpp example.

    You should also have read the examples that introduce the dlib DNN API before 
    continuing.  These are dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp.
*/


#include <dlib/dnn.h>
#include <iostream>

using namespace std;
using namespace dlib;


int main() try
{
    // The API for doing metric learning is very similar to the API for
    // multi-class classification.  In fact, the inputs are the same: a bunch of
    // labeled objects.  So here we create our dataset.  We make up some simple
    // vectors and label them with the integers 1,2,3,4.  The specific values of
    // the integer labels don't matter.
    std::vector<matrix<double,0,1>> samples;
    std::vector<unsigned long> labels;

    // class 1 training vectors
    samples.push_back({1,0,0,0,0,0,0,0}); labels.push_back(1);
    samples.push_back({0,1,0,0,0,0,0,0}); labels.push_back(1);

    // class 2 training vectors
    samples.push_back({0,0,1,0,0,0,0,0}); labels.push_back(2);
    samples.push_back({0,0,0,1,0,0,0,0}); labels.push_back(2);

    // class 3 training vectors
    samples.push_back({0,0,0,0,1,0,0,0}); labels.push_back(3);
    samples.push_back({0,0,0,0,0,1,0,0}); labels.push_back(3);

    // class 4 training vectors
    samples.push_back({0,0,0,0,0,0,1,0}); labels.push_back(4);
    samples.push_back({0,0,0,0,0,0,0,1}); labels.push_back(4);


    // Make a network that simply learns a linear mapping from 8D vectors to 2D
    // vectors.
    using net_type = loss_metric<fc<2,input<matrix<double,0,1>>>>; 
    net_type net;
    // Now setup the trainer and train the network using our data.
    dnn_trainer<net_type> trainer(net);
    trainer.set_learning_rate(0.1);
    trainer.set_min_learning_rate(0.001);
    trainer.set_mini_batch_size(128);
    trainer.be_verbose();
    trainer.set_iterations_without_progress_threshold(100);
    trainer.train(samples, labels);



    // Run all the samples through the network to get their 2D vector embeddings.
    std::vector<matrix<float,0,1>> embedded = net(samples);

    // Print the embedding for each sample to the screen.  If you look at the
    // outputs carefully, you should notice that they are grouped together in 2D
    // space according to their label.
    for (size_t i = 0; i < embedded.size(); ++i)
        cout << "label: " << labels[i] << "\t" << trans(embedded[i]);
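
    // Not part of the original example, but a sketch of the k-nearest neighbor
    // idea from the introduction: to classify a new vector we can embed it with
    // the trained network and take the label of the closest embedded training
    // sample.  The query vector below is made up for illustration.
    matrix<double,0,1> query = {0.9, 0.1, 0, 0, 0, 0, 0, 0};
    matrix<float,0,1> query_embedded = net(query);
    size_t nearest = 0;
    for (size_t i = 1; i < embedded.size(); ++i)
    {
        if (length(query_embedded - embedded[i]) < length(query_embedded - embedded[nearest]))
            nearest = i;
    }
    cout << "1-NN predicted label for query: " << labels[nearest] << endl;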

    // Now, check if the embedding puts things with the same labels near each other and
    // things with different labels far apart.
    int num_right = 0;
    int num_wrong = 0;
    for (size_t i = 0; i < embedded.size(); ++i)
    {
        for (size_t j = i+1; j < embedded.size(); ++j)
        {
            if (labels[i] == labels[j])
            {
                // The loss_metric layer will cause things with the same label to be less
                // than net.loss_details().get_distance_threshold() distance from each
                // other.  So we can use that distance value as our testing threshold for
                // "being near to each other".
                if (length(embedded[i]-embedded[j]) < net.loss_details().get_distance_threshold())
                    ++num_right;
                else
                    ++num_wrong;
            }
            else
            {
                if (length(embedded[i]-embedded[j]) >= net.loss_details().get_distance_threshold())
                    ++num_right;
                else
                    ++num_wrong;
            }
        }
    }

    cout << "num_right: "<< num_right << endl;
    cout << "num_wrong: "<< num_wrong << endl;
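
    // As a side note (not in the original example): a trained dlib network can
    // be saved to disk and reloaded later with dlib's serialization routines,
    // e.g.
    //   serialize("metric_network.dat") << net;
    //   deserialize("metric_network.dat") >> net;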
}
catch(const std::exception& e)
{
    cout << e.what() << endl;
}