// The contents of this file are in the public domain. 
// See LICENSE_FOR_EXAMPLE_PROGRAMS.txt (in trunk/examples)
// Authors:
//   Gregory Sharp
//   Davis King

/*
    This is a command line program that can try different regression 
    algorithms on a libsvm-formatted data set.
*/
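/*
    Example invocation (a sketch only: the executable name is taken from the
    usage message printed by this program, and the data file name is
    hypothetical):

        dlib_test --in data.libsvm -a svr --normalize
*/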

#include "regression.h"
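// (regression.h is assumed to provide the sparse_sample_type and
// dense_sample_type typedefs, sparse_to_dense(), and the krls_test, krr_test,
// mlp_test, and svr_test routines used below.)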

#include <iostream>
#include <map>
#include <vector>

#include "dlib/cmd_line_parser.h"
#include "dlib/data_io.h"
#include "dlib/svm.h"

using namespace dlib;

// ----------------------------------------------------------------------------------------

static void
parse_args (command_line_parser& parser, int argc, char* argv[])
{
    try {
        // Algorithm-independent options
        parser.add_option ("a",
                           "Choose the learning algorithm: {krls,krr,mlp,svr}.",1);
        parser.add_option ("h","Display this help message.");
        parser.add_option ("help","Display this help message.");
        parser.add_option ("k",
                           "Learning kernel (for krls,krr,svr methods): {lin,rbk}.",1);
        parser.add_option ("in","A libsvm-formatted file to test.",1);
        parser.add_option ("normalize",
                           "Normalize the sample inputs to zero-mean unit variance.");
        parser.add_option ("train-best",
                           "Train and save a network using best parameters", 1);

        // Algorithm-specific options
        parser.add_option ("rbk-gamma",
                           "Width of radial basis kernels: {float}.",1);
        parser.add_option ("krls-tolerance",
                           "Numerical tolerance of krls linear dependency test: {float}.",1);
        parser.add_option ("mlp-hidden-units",
                           "Number of hidden units in mlp: {integer}.",1);
        parser.add_option ("mlp-num-iterations",
                           "Number of epochs to train the mlp: {integer}.",1);
        parser.add_option ("svr-c",
                           "SVR regularization parameter \"C\": "
                           "{float}.",1);
        parser.add_option ("svr-epsilon-insensitivity",
                           "SVR fitting tolerance parameter: "
                           "{float}.",1);
        parser.add_option ("verbose", "Use verbose trainers");

        // Parse the command line arguments
        parser.parse(argc,argv);

        // Check that options aren't given multiple times
        const char* one_time_opts[] = {"a", "h", "help", "in"};
        parser.check_one_time_options(one_time_opts);

        const char* valid_kernels[] = {"lin", "rbk"};
        const char* valid_algs[]    = {"krls", "krr", "mlp", "svr"};
        parser.check_option_arg_range("a", valid_algs);
        parser.check_option_arg_range("k", valid_kernels);
        parser.check_option_arg_range("rbk-gamma", 1e-200, 1e200);
        parser.check_option_arg_range("krls-tolerance", 1e-200, 1e200);
        parser.check_option_arg_range("mlp-hidden-units", 1, 10000000);
        parser.check_option_arg_range("mlp-num-iterations", 1, 10000000);
        parser.check_option_arg_range("svr-c", 1e-200, 1e200);
        parser.check_option_arg_range("svr-epsilon-insensitivity", 1e-200, 1e200);
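        // Note: the check_* calls above throw a dlib exception if an option
        // violates one of these constraints; those exceptions are reported by
        // the catch blocks at the end of this function.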

        // Check if the -h option was given
        if (parser.option("h") || parser.option("help")) {
            std::cout << "Usage: dlib_test [-a algorithm] --in input_file\n";
            parser.print_options(std::cout);
            std::cout << std::endl;
            exit (0);
        }

        // Check that an input file was given
        if (!parser.option("in")) {
            std::cout 
                << "Error in command line:\n"
                << "You must specify an input file with the --in option.\n"
                << "\nTry the -h option for more information\n";
            exit (0);
        }
    }
    catch (std::exception& e) {
        // Catch cmd_line_parse_error exceptions and print usage message.
        std::cout << e.what() << std::endl;
        exit (1);
    }
    catch (...) {
        std::cout << "Some error occurred" << std::endl;
    }
}

// ----------------------------------------------------------------------------------------

int 
main (int argc, char* argv[])
{
    command_line_parser parser;

    parse_args(parser, argc, argv);


    std::vector<sparse_sample_type> sparse_samples;
    std::vector<double> labels;

    load_libsvm_formatted_data (
        parser.option("in").argument(), 
        sparse_samples, 
        labels
    );
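    // Each line of the libsvm-formatted file becomes one sparse sample (a
    // vector of index/value pairs) and one entry in labels.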

    if (sparse_samples.size() < 1) {
        std::cout 
            << "Sorry, I couldn't find any samples in your data set.\n"
            << "Aborting the operation.\n";
        exit (0);
    }

    std::vector<dense_sample_type> dense_samples;
    dense_samples = sparse_to_dense (sparse_samples);

    /* GCS FIX: The sparse_to_dense converter adds an extra column, 
       because libsvm files are indexed starting with "1". */
    std::cout 
        << "Loaded " << sparse_samples.size() << " samples"
        << std::endl
        << "Each sample has size " << sparse_samples[0].size() 
        << std::endl
        << "Each dense sample has size " << dense_samples[0].size() 
        << std::endl;

    // Normalize inputs to N(0,1)
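    // vector_normalizer::train() estimates the per-dimension mean and standard
    // deviation of the samples; applying the trained normalizer then maps each
    // sample to approximately zero mean and unit variance.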
    if (parser.option ("normalize")) {
        vector_normalizer<dense_sample_type> normalizer;
        normalizer.train (dense_samples);
        for (unsigned long i = 0; i < dense_samples.size(); ++i) {
            dense_samples[i] = normalizer (dense_samples[i]);
        }
    }

    // Randomize the order of the samples, labels
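    // (randomize_samples shuffles both vectors with the same permutation, so
    // each label stays paired with its sample.)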
    randomize_samples (dense_samples, labels);

    const command_line_parser::option_type& option_alg = parser.option("a");
    if (!option_alg) {
        // Do KRR if user didn't specify an algorithm
        std::cout << "No algorithm specified, default to KRR\n";
        krr_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "krls") {
        krls_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "krr") {
        krr_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "mlp") {
        mlp_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "svr") {
        svr_test (parser, dense_samples, labels);
    }
    else {
        fprintf (stderr, 
                 "Error, algorithm \"%s\" is unknown.\n"
                 "Please use -h to see the command line options\n",
                 option_alg.argument().c_str());
        exit (-1);
    }
}

// ----------------------------------------------------------------------------------------