// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This is an example illustrating the use of the RVM regression object 
    from the dlib C++ Library.

    This example will train on data from the sinc function.

*/

#include <iostream>
#include <vector>

#include <dlib/svm.h>

using namespace std;
using namespace dlib;

// The target function for the regression: sinc(x) = sin(x)/x.
// The quotient is undefined at x == 0, so that point is defined to be 1.
double sinc(double x)
{
    return (x == 0) ? 1 : sin(x)/x;
}

int main()
{
    // Here we declare that our samples will be 1 dimensional column vectors.  
    typedef matrix<double,1,1> sample_type;

    // Now sample some points from the sinc() function
    sample_type m;
    std::vector<sample_type> samples;
    std::vector<double> labels;
    for (double x = -10; x <= 4; x += 1)
    {
        m(0) = x;
        samples.push_back(m);
        labels.push_back(sinc(x));
    }

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
    // Now we are making a typedef for the kind of kernel we want to use.  I picked the
    // radial basis kernel because it only has one parameter and generally gives good
    // results without much fiddling.
    typedef radial_basis_kernel<sample_type> kernel_type;

    // Here we declare an instance of the rvm_regression_trainer object.  This is the
    // object that we will later use to do the training.
    rvm_regression_trainer<kernel_type> trainer;

    // Here we set the kernel we want to use for training.   The radial_basis_kernel 
    // has a parameter called gamma that we need to determine.  As a rule of thumb, a good 
    // gamma to try is 1.0/(mean squared distance between your sample points).  So 
    // below we are using a similar value.   Note also that using an inappropriately large
    // gamma will cause the RVM training algorithm to run extremely slowly.  What
    // "large" means is relative to how spread out your data is.  So it is important
    // to use a rule like this as a starting point for determining the gamma value
    // if you want to use the RVM.  It is also probably a good idea to normalize your
    // samples as shown in the rvm_ex.cpp example program.
    const double gamma = 2.0/compute_mean_squared_distance(samples);
    cout << "using gamma of " << gamma << endl;
    trainer.set_kernel(kernel_type(gamma));

64
65
66
67
68
    // One thing you can do to reduce the RVM training time is to make its
    // stopping epsilon bigger.  However, this might make the outputs less
    // reliable.  But sometimes it works out well.  0.001 is the default.
    trainer.set_epsilon(0.001);

Davis King's avatar
Davis King committed
69
70
71
72
73
74
75
76
77
78
79
    // now train a function based on our sample points
    decision_function<kernel_type> test = trainer.train(samples, labels);

    // now we output the value of the sinc function for a few test points as well as the 
    // value predicted by our regression.
    m(0) = 2.5; cout << sinc(m(0)) << "   " << test(m) << endl;
    m(0) = 0.1; cout << sinc(m(0)) << "   " << test(m) << endl;
    m(0) = -4;  cout << sinc(m(0)) << "   " << test(m) << endl;
    m(0) = 5.0; cout << sinc(m(0)) << "   " << test(m) << endl;

    // The output is as follows:
80
    //using gamma of 0.05
Davis King's avatar
Davis King committed
81
82
83
84
85
86
87
88
89
    //0.239389   0.240989
    //0.998334   0.999538
    //-0.189201   -0.188453
    //-0.191785   -0.226516


    // The first column is the true value of the sinc function and the second
    // column is the output from the rvm estimate.  

90
91
92
93


    // Another thing that is worth knowing is that just about everything in dlib is serializable.
    // So for example, you can save the test object to disk and recall it later like so:
94
    serialize("saved_function.dat") << test;
95

Davis King's avatar
Davis King committed
96
    // Now let's open that file back up and load the function object it contains.
97
    deserialize("saved_function.dat") >> test;
98

Davis King's avatar
Davis King committed
99
100
101
}