"docs/vscode:/vscode.git/clone" did not exist on "9a7235faf2835d424c4587e703024248e6b9f465"
model_selection_ex.cpp 6.87 KB
Newer Older
1
2
3
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*

    This is an example that shows how you can perform model selection with the
    dlib C++ Library.  

    It will create a simple dataset and show you how to use cross validation and
    global optimization to determine good parameters for the purpose of training
    an svm to classify the data.

    The data used in this example will be 2-dimensional and will come from a
    distribution where points within a distance of 10 from the origin are
    labeled +1 and all other points are labeled -1.

    As an aside, you should probably read the svm_ex.cpp and matrix_ex.cpp example
    programs before you read this one.
*/


#include <iostream>
#include <iomanip>
#include <dlib/svm.h>
#include <dlib/global_optimization.h>

using namespace std;
using namespace dlib;


int main() try
{
    // The svm functions use column vectors to contain a lot of the data on which they 
    // operate. So the first thing we do here is declare a convenient typedef.  

    // This typedef declares a matrix with 2 rows and 1 column.  It will be the
    // object that contains each of our 2 dimensional samples.   
    typedef matrix<double, 2, 1> sample_type;



    // Now we make objects to contain our samples and their respective labels.
    std::vector<sample_type> samples;
    std::vector<double> labels;

    // Now let's put some data into our samples and labels objects.  We do this
    // by looping over a bunch of points and labeling them according to their
    // distance from the origin.
    for (double r = -20; r <= 20; r += 0.8)
    {
        for (double c = -20; c <= 20; c += 0.8)
        {
            sample_type samp;
            samp(0) = r;
            samp(1) = c;
            samples.push_back(samp);

            // if this point is within a distance of 10 from the origin
            if (sqrt(r*r + c*c) <= 10)
                labels.push_back(+1);
            else
                labels.push_back(-1);
        }
    }
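    // (With a step of 0.8 over [-20, 20] both loops run 51 times, so this grid
    // yields about 51*51 = 2601 samples; the exact count may differ by a row or
    // column because of floating-point stepping.)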

    cout << "Generated " << samples.size() << " points" << endl;


    // Here we normalize all the samples by subtracting their mean and dividing by
    // their standard deviation.  This is generally a good idea since it often heads
    // off numerical stability problems and also prevents one large feature from
    // smothering others.  It doesn't matter much in this example, but I'm doing it
    // here so you can see an easy way to accomplish it with the library.
    vector_normalizer<sample_type> normalizer;
    // let the normalizer learn the mean and standard deviation of the samples
    normalizer.train(samples);
    // now normalize each sample
    for (unsigned long i = 0; i < samples.size(); ++i)
        samples[i] = normalizer(samples[i]); 
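
    // Note: any sample you classify later must be run through this same normalizer
    // before being handed to the learned decision function.  dlib's
    // normalized_function object (demonstrated in svm_ex.cpp) can bundle the
    // normalizer and the decision function together so this step can't be forgotten.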


    // Now that we have some data we want to train on it.  We are going to train a
    // binary SVM with the RBF kernel to classify the data.  However, there are two
    // parameters to the training.  These are the nu and gamma parameters.  Our choice
    // for these parameters will influence how good the resulting decision function is.
    // To test how good a particular choice of these parameters is we can use the
    // cross_validate_trainer() function to perform n-fold cross validation on our
    // training data.  However, there is a problem with the way we have sampled our
    // distribution above.  The problem is that there is a definite ordering to the
    // samples.  That is, the first half of the samples look like they are from a
    // different distribution than the second half.  This would screw up the cross
    // validation process, but we can fix it by randomizing the order of the samples
    // with the following function call.
    randomize_samples(samples, labels);



    // And now we get to the important bit.  Here we define a function,
    // cross_validation_score(), that will do the cross-validation we
    // mentioned and return a number indicating how good a particular setting
    // of gamma and nu is.
    auto cross_validation_score = [&](const double gamma, const double nu) 
    {
        // Make a RBF SVM trainer and tell it what the parameters are supposed to be.
        typedef radial_basis_kernel<sample_type> kernel_type;
        svm_nu_trainer<kernel_type> trainer;
        trainer.set_kernel(kernel_type(gamma));
        trainer.set_nu(nu);

        // Finally, perform 10-fold cross validation and then print and return the results.
        matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
        cout << "gamma: " << setw(11) << gamma << "  nu: " << setw(11) << nu <<  "  cross validation accuracy: " << result;

        // Now return a number indicating how good the parameters are.  Bigger is
        // better in this example.  Here I'm returning the harmonic mean between the
        // accuracies of each class.  However, you could do something else.  For
        // example, you might care a lot more about correctly predicting the +1 class,
        // so you could penalize results that didn't obtain a high accuracy on that
        // class.  You might do this by using something like a weighted version of the
        // F1-score (see http://en.wikipedia.org/wiki/F1_score).     
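        //
        // Concretely, cross_validate_trainer() returns a 1x2 row vector [a, b]
        // giving the fraction of +1 and -1 test samples classified correctly, so
        // 2*prod(result)/sum(result) is 2*a*b/(a+b), the harmonic mean of the two.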
        return 2*prod(result)/sum(result);
    };

    // The nu parameter has a maximum value that is dependent on the ratio of the +1 to -1
    // labels in the training data.  This function finds that value.  The 0.999 is here
    // because the maximum allowable nu is strictly less than the value returned by
    // maximum_nu().  So shrinking the limit a little will prevent us from hitting it.
    const double max_nu = 0.999*maximum_nu(labels);
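    // (For reference, maximum_nu(labels) works out to 2*min(#pos, #neg)/(#pos + #neg),
    // so the more unbalanced the labels, the smaller the allowable range for nu.)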


    // And finally, we call this global optimizer that will search for the best parameters.
    // It will call cross_validation_score() 50 times with different settings and return
    // the best parameter setting it finds.  find_max_global() uses a global optimization
    // method based on a combination of non-parametric global function modeling and
    // quadratic trust region modeling to efficiently find a global maximizer.  It usually
    // does a good job with a relatively small number of calls to cross_validation_score().
    // In this example, you should observe that it finds settings that give perfect binary
    // classification on the data.
    auto result = find_max_global(cross_validation_score, 
                                  {1e-5, 1e-5},  // lower bound constraints on gamma and nu, respectively
                                  {100, max_nu}, // upper bound constraints on gamma and nu, respectively
                                  max_function_calls(50));
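    // find_max_global() returns a function_evaluation object: result.x holds the
    // best arguments it found and result.y holds the corresponding objective value
    // (here, the best cross-validation score).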

    double best_gamma = result.x(0);
    double best_nu    = result.x(1);

    cout << " best cross-validation score: " << result.y << endl;
    cout << " best gamma: " << best_gamma << "   best nu: " << best_nu << endl;


}
catch (exception& e)
{
    cout << e.what() << endl;
}