"vscode:/vscode.git/clone" did not exist on "e7d7d2705c090145bff51dc140b21013e60e9c15"
main.cpp 6.46 KB
Newer Older
1
// The contents of this file are in the public domain. 
// See LICENSE_FOR_EXAMPLE_PROGRAMS.txt (in trunk/examples)
// Authors:
//   Gregory Sharp
//   Davis King

/*
    This is a command line program that can try different regression 
    algorithms on a libsvm-formatted data set.
*/
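
/*
    Example usage (sketch): the option names below come from parse_args(),
    and the program name matches the usage string it prints; the data file
    name is made up for illustration.

        dlib_test -a krr -k rbk --normalize --in housing.libsvm

    A libsvm-formatted file holds one sample per line as
    "<label> <index>:<value> ...", with feature indices starting at 1,
    for example:

        21.5 1:0.38 2:6.1 4:296
        13.9 1:1.25 3:0.87 4:666
*/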

#include "regression.h"

#include <cstdio>      // fprintf() in main()'s error path
#include <cstdlib>     // exit()
#include <iostream>
#include <map>
#include <vector>

#include "dlib/data_io.h"
#include "dlib/svm.h"

using namespace dlib;

// ----------------------------------------------------------------------------------------

static void
parse_args (clp& parser, int argc, char* argv[])
{
    try {
        // Algorithm-independent options
        parser.add_option ("a",
                           "Choose the learning algorithm: {krls,krr,mlp,svr}.",1);
        parser.add_option ("h","Display this help message.");
        parser.add_option ("help","Display this help message.");
        parser.add_option ("k",
                           "Learning kernel (for krls,krr,svr methods): {lin,rbk}.",1);
        parser.add_option ("in","A libsvm-formatted file to test.",1);
        parser.add_option ("normalize",
                           "Normalize the sample inputs to zero-mean unit variance.");
        parser.add_option ("train-best",
                           "Train and save a network using best parameters", 1);

        // Algorithm-specific options
        parser.add_option ("rbk-gamma",
                           "Width of radial basis kernels: {float}.",1);
        parser.add_option ("krls-tolerance",
                           "Numerical tolerance of krls linear dependency test: {float}.",1);
        parser.add_option ("mlp-hidden-units",
                           "Number of hidden units in mlp: {integer}.",1);
        parser.add_option ("mlp-num-iterations",
                           "Number of epochs to train the mlp: {integer}.",1);
        parser.add_option ("svr-c",
                           "SVR regularization parameter \"C\": "
                           "{float}.",1);
        parser.add_option ("svr-epsilon-insensitivity",
                           "SVR fitting tolerance parameter: "
                           "{float}.",1);
        parser.add_option ("verbose", "Use verbose trainers");

        // Parse the command line arguments
        parser.parse(argc,argv);

        // Check that options aren't given multiple times
        const char* one_time_opts[] = {"a", "h", "help", "in"};
        parser.check_one_time_options(one_time_opts);

        const char* valid_kernels[] = {"lin", "rbk"};
        const char* valid_algs[]    = {"krls", "krr", "mlp", "svr"};
        parser.check_option_arg_range("a", valid_algs);
        parser.check_option_arg_range("k", valid_kernels);
        parser.check_option_arg_range("rbk-gamma", 1e-200, 1e200);
        parser.check_option_arg_range("krls-tolerance", 1e-200, 1e200);
        parser.check_option_arg_range("mlp-hidden-units", 1, 10000000);
        parser.check_option_arg_range("mlp-num-iterations", 1, 10000000);
        parser.check_option_arg_range("svr-c", 1e-200, 1e200);
        parser.check_option_arg_range("svr-epsilon-insensitivity", 1e-200, 1e200);

        // Check if the -h option was given
        if (parser.option("h") || parser.option("help")) {
            std::cout << "Usage: dlib_test [-a algorithm] --in input_file\n";
            parser.print_options(std::cout);
            std::cout << std::endl;
            exit (0);
        }

        // Check that an input file was given
        if (!parser.option("in")) {
            std::cout
                << "Error in command line:\n"
                << "You must specify an input file with the --in option.\n"
                << "\nTry the -h option for more information\n";
            exit (0);
        }
    }
    catch (std::exception& e) {
        // Catch cmd_line_parse_error exceptions and print usage message.
        std::cout << e.what() << std::endl;
        exit (1);
    }
    catch (...) {
        std::cout << "Some error occurred" << std::endl;
        // Don't fall through to main() with a half-parsed command line.
        exit (1);
    }
}

// ----------------------------------------------------------------------------------------

int 
main (int argc, char* argv[])
{
    clp parser;

    parse_args(parser, argc, argv);

    const clp::option_type& option_alg = parser.option("a");
    const clp::option_type& option_in = parser.option("in");

    std::vector<sparse_sample_type> sparse_samples;
    std::vector<double> labels;

    load_libsvm_formatted_data (
        option_in.argument(), 
        sparse_samples, 
        labels
    );

    if (sparse_samples.size() < 1) {
        std::cout
            << "Sorry, I couldn't find any samples in your data set.\n"
            << "Aborting the operation.\n";
        exit (0);
    }

    std::vector<dense_sample_type> dense_samples;
    dense_samples = sparse_to_dense (sparse_samples);

    /* GCS FIX: The sparse_to_dense converter adds an extra column, 
       because libsvm files are indexed starting with "1". */
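    /* E.g., if the largest feature index in the file is 13, each dense
       sample has size 14 and column 0 stays zero. (Illustrative numbers
       only.) */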
    std::cout
        << "Loaded " << sparse_samples.size() << " samples"
        << std::endl
        << "Each sample has size " << sparse_samples[0].size() 
        << std::endl
        << "Each dense sample has size " << dense_samples[0].size() 
        << std::endl;

    // Normalize inputs to N(0,1)
    if (parser.option ("normalize")) {
148
149
150
151
152
        vector_normalizer<dense_sample_type> normalizer;
        normalizer.train (dense_samples);
        for (unsigned long i = 0; i < dense_samples.size(); ++i) {
            dense_samples[i] = normalizer (dense_samples[i]);
        }
153
154
155
156
157
158
    }

    // Randomize the order of the samples, labels
    randomize_samples (dense_samples, labels);

    if (!option_alg) {
        // Do KRR if user didn't specify an algorithm
        std::cout << "No algorithm specified, default to KRR\n";
        krr_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "krls") {
        krls_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "krr") {
        krr_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "mlp") {
        mlp_test (parser, dense_samples, labels);
    }
    else if (option_alg.argument() == "svr") {
        svr_test (parser, dense_samples, labels);
    }
    else {
        fprintf (stderr, 
                 "Error, algorithm \"%s\" is unknown.\n"
                 "Please use -h to see the command line options\n",
                 option_alg.argument().c_str());
        exit (-1);
    }
}

// ----------------------------------------------------------------------------------------