// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*

    This is an example that shows how you can perform model selection with the
    dlib C++ Library.

    It will create a simple dataset and show you how to use cross validation and
    global optimization to determine good parameters for the purpose of training
    an svm to classify the data.

    The data used in this example will be 2 dimensional data and will come from a
    distribution where points with a distance less than 10 from the origin are
    labeled +1 and all other points are labeled as -1.

    As an aside, you should probably read the svm_ex.cpp and matrix_ex.cpp example
    programs before you read this one.

*/


#include <iomanip>
#include <iostream>

#include <dlib/global_optimization.h>
#include <dlib/svm.h>

using namespace std;
using namespace dlib;


29
30
31
32
int main() try
{
    // The svm functions use column vectors to contain a lot of the data on which they 
    // operate. So the first thing we do here is declare a convenient typedef.  
33

34
35
36
    // This typedef declares a matrix with 2 rows and 1 column.  It will be the
    // object that contains each of our 2 dimensional samples.   
    typedef matrix<double, 2, 1> sample_type;
37
38
39



40
41
42
    // Now we make objects to contain our samples and their respective labels.
    std::vector<sample_type> samples;
    std::vector<double> labels;
43

44
45
46
47
    // Now let's put some data into our samples and labels objects.  We do this
    // by looping over a bunch of points and labeling them according to their
    // distance from the origin.
    for (double r = -20; r <= 20; r += 0.8)
48
    {
49
50
51
52
53
54
55
56
57
58
59
60
61
62
        for (double c = -20; c <= 20; c += 0.8)
        {
            sample_type samp;
            samp(0) = r;
            samp(1) = c;
            samples.push_back(samp);

            // if this point is less than 10 from the origin
            if (sqrt(r*r + c*c) <= 10)
                labels.push_back(+1);
            else
                labels.push_back(-1);
        }
    }
63

64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
    cout << "Generated " << samples.size() << " points" << endl;


    // Here we normalize all the samples by subtracting their mean and dividing by their
    // standard deviation.  This is generally a good idea since it often heads off
    // numerical stability problems and also prevents one large feature from smothering
    // others.  Doing this doesn't matter much in this example so I'm just doing this here
    // so you can see an easy way to accomplish this with the library.  
    vector_normalizer<sample_type> normalizer;
    // let the normalizer learn the mean and standard deviation of the samples
    normalizer.train(samples);
    // now normalize each sample
    for (unsigned long i = 0; i < samples.size(); ++i)
        samples[i] = normalizer(samples[i]); 


    // Now that we have some data we want to train on it.  We are going to train a
81
82
83
84
85
    // binary SVM with the RBF kernel to classify the data.  However, there are
    // three parameters to the training.  These are the SVM C parameters for each
    // class and the RBF kernel's gamma parameter.  Our choice for these
    // parameters will influence how good the resulting decision function is.  To
    // test how good a particular choice of these parameters is we can use the
86
    // cross_validate_trainer() function to perform n-fold cross validation on our
87
88
89
90
91
92
    // training data.  However, there is a problem with the way we have sampled
    // our distribution above.  The problem is that there is a definite ordering
    // to the samples.  That is, the first half of the samples look like they are
    // from a different distribution than the second half.  This would screw up
    // the cross validation process, but we can fix it by randomizing the order of
    // the samples with the following function call.
93
94
95
96
97
98
    randomize_samples(samples, labels);


    // And now we get to the important bit.  Here we define a function,
    // cross_validation_score(), that will do the cross-validation we
    // mentioned and return a number indicating how good a particular setting
99
100
    // of gamma, c1, and c2 is.
    auto cross_validation_score = [&](const double gamma, const double c1, const double c2) 
101
102
103
    {
        // Make a RBF SVM trainer and tell it what the parameters are supposed to be.
        typedef radial_basis_kernel<sample_type> kernel_type;
104
        svm_c_trainer<kernel_type> trainer;
105
        trainer.set_kernel(kernel_type(gamma));
106
107
        trainer.set_c_class1(c1);
        trainer.set_c_class2(c2);
108

109
        // Finally, perform 10-fold cross validation and then print and return the results.
110
        matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
111
        cout << "gamma: " << setw(11) << gamma << "  c1: " << setw(11) << c1 <<  "  c2: " << setw(11) << c2 <<  "  cross validation accuracy: " << result;
112

113
114
115
116
117
118
119
        // Now return a number indicating how good the parameters are.  Bigger is
        // better in this example.  Here I'm returning the harmonic mean between the
        // accuracies of each class.  However, you could do something else.  For
        // example, you might care a lot more about correctly predicting the +1 class,
        // so you could penalize results that didn't obtain a high accuracy on that
        // class.  You might do this by using something like a weighted version of the
        // F1-score (see http://en.wikipedia.org/wiki/F1_score).     
120
        return 2*prod(result)/sum(result);
121
    };
122
123


124
    // And finally, we call this global optimizer that will search for the best parameters.
125
    // It will call cross_validation_score() 30 times with different settings and return
126
127
128
129
130
    // the best parameter setting it finds.  find_max_global() uses a global optimization
    // method based on a combination of non-parametric global function modeling and
    // quadratic trust region modeling to efficiently find a global maximizer.  It usually
    // does a good job with a relatively small number of calls to cross_validation_score().
    // In this example, you should observe that it finds settings that give perfect binary
131
    // classification of the data.
132
    auto result = find_max_global(cross_validation_score, 
133
134
135
                                  {1e-5, 1e-5, 1e-5},  // lower bound constraints on gamma, c1, and c2, respectively
                                  {100,  1e6,  1e6},   // upper bound constraints on gamma, c1, and c2, respectively
                                  max_function_calls(30));
136

137
    double best_gamma = result.x(0);
138
139
    double best_c1    = result.x(1);
    double best_c2    = result.x(2);
140

141
    cout << " best cross-validation score: " << result.y << endl;
142
    cout << " best gamma: " << best_gamma << "   best c1: " << best_c1 << "    best c2: "<< best_c2  << endl;
143
144
145
146
}
catch (exception& e)
{
    cout << e.what() << endl;
147
148
}