// The contents of this file are in the public domain. 
// See LICENSE_FOR_EXAMPLE_PROGRAMS.txt (in trunk/examples)
// Authors:
//   Gregory Sharp
//   Davis King

/*
    This is a command line program that can try different regression 
    algorithms on a libsvm-formatted data set.
*/
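
/*
    Example usage (the file name and option values below are hypothetical and
    shown only for illustration):

        mltool -a krr -k rbk --normalize --in data.libsvm

    Each line of a libsvm-formatted input file holds one sample: a label
    followed by 1-based index:value pairs, e.g.

        3.2  1:0.52  4:-1.7  13:0.0031
*/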

#include "regression.h"
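// regression.h is expected to declare sparse_sample_type, dense_sample_type,
// and the krls_test/krr_test/mlp_test/svr_test functions used in main() below.
// As a rough sketch of what this file assumes (the actual header is not shown
// here), the sample types are along the lines of:
//
//     typedef std::map<unsigned long, double>   sparse_sample_type;
//     typedef dlib::matrix<double, 0, 1>        dense_sample_type;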

#include <iostream>
#include <map>
#include <vector>

#include "dlib/cmd_line_parser.h"
#include "dlib/data_io.h"
#include "dlib/svm.h"

using namespace dlib;

// ----------------------------------------------------------------------------------------

static void
parse_args (command_line_parser& parser, int argc, char* argv[])
{
    try {
        // Algorithm-independent options
        parser.add_option ("a",
                           "Choose the learning algorithm: {krls,krr,mlp,svr}.",1);
        parser.add_option ("h","Display this help message.");
        parser.add_option ("help","Display this help message.");
        parser.add_option ("k",
                           "Learning kernel (for krls,krr,svr methods): {lin,rbk}.",1);
        parser.add_option ("in","A libsvm-formatted file to test.",1);
        parser.add_option ("normalize",
                           "Normalize the sample inputs to zero-mean unit variance.");
        parser.add_option ("train-best",
                           "Train and save a network using best parameters", 1);

        parser.set_group_name("Algorithm Specific Options");
        // Algorithm-specific options
        parser.add_option ("rbk-gamma",
                           "Width of radial basis kernels: {float}.",1);
        parser.add_option ("krls-tolerance",
                           "Numerical tolerance of krls linear dependency test: {float}.",1);
        parser.add_option ("mlp-hidden-units",
                           "Number of hidden units in mlp: {integer}.",1);
        parser.add_option ("mlp-num-iterations",
                           "Number of epochs to train the mlp: {integer}.",1);
        parser.add_option ("svr-c",
                           "SVR regularization parameter \"C\": "
                           "{float}.",1);
        parser.add_option ("svr-epsilon-insensitivity",
                           "SVR fitting tolerance parameter: "
                           "{float}.",1);
        parser.add_option ("verbose", "Use verbose trainers");

        // Parse the command line arguments
        parser.parse(argc,argv);

        // Check that options aren't given multiple times
        const char* one_time_opts[] = {"a", "h", "help", "in"};
        parser.check_one_time_options(one_time_opts);

        const char* valid_kernels[] = {"lin", "rbk"};
        const char* valid_algs[]    = {"krls", "krr", "mlp", "svr"};
        parser.check_option_arg_range("a", valid_algs);
        parser.check_option_arg_range("k", valid_kernels);
        parser.check_option_arg_range("rbk-gamma", 1e-200, 1e200);
        parser.check_option_arg_range("krls-tolerance", 1e-200, 1e200);
        parser.check_option_arg_range("mlp-hidden-units", 1, 10000000);
        parser.check_option_arg_range("mlp-num-iterations", 1, 10000000);
        parser.check_option_arg_range("svr-c", 1e-200, 1e200);
        parser.check_option_arg_range("svr-epsilon-insensitivity", 1e-200, 1e200);

        // Check if the -h or --help option was given
        if (parser.option("h") || parser.option("help")) {
            std::cout << "Usage: mltool [-a algorithm] --in input_file\n";
            parser.print_options(std::cout);
            std::cout << std::endl;
            exit (0);
86
87
        }

        // Check that an input file was given
        if (!parser.option("in")) {
            std::cout 
                << "Error in command line:\n"
                << "You must specify an input file with the --in option.\n"
                << "\nTry the -h option for more information\n";
            exit (0);
        }
    }
    catch (std::exception& e) {
        // Catch cmd_line_parse_error exceptions and print the error message.
        std::cout << e.what() << std::endl;
        exit (1);
    }
    catch (...) {
        std::cout << "Some error occurred" << std::endl;
        exit (1);
    }
}

// ----------------------------------------------------------------------------------------

int 
main (int argc, char* argv[])
{
    command_line_parser parser;

    parse_args(parser, argc, argv);


    std::vector<sparse_sample_type> sparse_samples;
    std::vector<double> labels;

    load_libsvm_formatted_data (
        parser.option("in").argument(), 
        sparse_samples, 
        labels
    );

    if (sparse_samples.size() < 1) {
        std::cout 
            << "Sorry, I couldn't find any samples in your data set.\n"
            << "Aborting the operation.\n";
        exit (0);
    }

    std::vector<dense_sample_type> dense_samples;
    dense_samples = sparse_to_dense (sparse_samples);

    /* GCS FIX: The sparse_to_dense converter adds an extra column, 
       because libsvm files are indexed starting with "1". */
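    /* For illustration (hypothetical numbers): if the features in the file are
       indexed 1 through 13, each dense sample comes back with 14 elements,
       with element 0 left unused. */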
    std::cout 
        << "Loaded " << sparse_samples.size() << " samples"
        << std::endl
        << "Each sample has size " << sparse_samples[0].size() 
        << std::endl
        << "Each dense sample has size " << dense_samples[0].size() 
        << std::endl;

    // Normalize inputs to N(0,1)
    if (parser.option ("normalize")) {
        vector_normalizer<dense_sample_type> normalizer;
        normalizer.train (dense_samples);
        for (unsigned long i = 0; i < dense_samples.size(); ++i) {
            dense_samples[i] = normalizer (dense_samples[i]);
        }
    }

    // Randomize the order of the samples, labels
    randomize_samples (dense_samples, labels);

    const command_line_parser::option_type& option_alg = parser.option("a");
    if (!option_alg) {
        // Do KRR if user didn't specify an algorithm
        std::cout << "No algorithm specified, defaulting to KRR\n";
        krr_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "krls") {
        krls_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "krr") {
        krr_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "mlp") {
        mlp_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "svr") {
        svr_test (parser, dense_samples, labels);
    }
    else {
        fprintf (stderr, 
                 "Error, algorithm \"%s\" is unknown.\n"
                 "Please use -h to see the command line options\n",
                 option_alg.argument().c_str());
        exit (-1);
    }
}

// ----------------------------------------------------------------------------------------