svm_c_trainer.cpp 12 KB
Newer Older
1
2
// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
3

Davis King's avatar
Davis King committed
4
#include "opaque_types.h"
5
#include <dlib/python.h>
6
#include "testing_results.h"
7
#include <dlib/matrix.h>
8
#include <dlib/svm_threaded.h>
9
10
11
12
13

using namespace dlib;
using namespace std;

// Dense column-vector sample type used by the dense-kernel trainers bound below.
using sample_type = matrix<double,0,1>;
Davis King's avatar
Davis King committed
14
// Sparse sample type: a list of (index, value) pairs, used by the sparse-kernel trainers.
using sparse_vect = std::vector<std::pair<unsigned long, double>>;
15

Davis King's avatar
Davis King committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// Python-facing train(): rejects (samples, labels) through pyassert unless they
// pass is_binary_classification_problem(), otherwise delegates to
// trainer.train(samples, labels) and hands back the learned function.
template <typename trainer_type>
typename trainer_type::trained_function_type train (
    const trainer_type& trainer,
    const std::vector<typename trainer_type::sample_type>& samples,
    const std::vector<double>& labels
)
{
    pyassert(is_binary_classification_problem(samples,labels), "Invalid inputs");

    typename trainer_type::trained_function_type learned = trainer.train(samples, labels);
    return learned;
}

// Setter bound to the Python "epsilon" property.
// A non-positive eps is rejected through pyassert; otherwise it is
// forwarded unchanged to trainer.set_epsilon().
template <typename trainer_type>
void set_epsilon (
    trainer_type& trainer,
    double eps
)
{
    pyassert(eps > 0, "epsilon must be > 0");

    trainer.set_epsilon(eps);
}

template <typename trainer_type>
double get_epsilon (
    const trainer_type& trainer
)
{
    // Getter bound to the Python "epsilon" property.
    const double eps = trainer.get_epsilon();
    return eps;
}


// Setter bound to the Python "cache_size" property.
// Non-positive sizes are rejected through pyassert before the trainer is touched.
template <typename trainer_type>
void set_cache_size (
    trainer_type& trainer,
    long cache_size
)
{
    pyassert(cache_size > 0, "cache size must be > 0");

    trainer.set_cache_size(cache_size);
}
44

Davis King's avatar
Davis King committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
template <typename trainer_type>
long get_cache_size (
    const trainer_type& trainer
)
{
    // Getter bound to the Python "cache_size" property.
    const long size = trainer.get_cache_size();
    return size;
}


// Backs the Python set_c() method.
// C must be strictly positive (enforced with pyassert); the value is then
// handed to trainer.set_c().
template <typename trainer_type>
void set_c (
    trainer_type& trainer,
    double C
)
{
    pyassert(C > 0, "C must be > 0");

    trainer.set_c(C);
}

// Setter bound to the Python "c_class1" property.
// C must be strictly positive (enforced with pyassert).
template <typename trainer_type>
void set_c_class1 (
    trainer_type& trainer,
    double C
)
{
    pyassert(C > 0, "C must be > 0");

    trainer.set_c_class1(C);
}

// Setter bound to the Python "c_class2" property.
// C must be strictly positive (enforced with pyassert).
template <typename trainer_type>
void set_c_class2 (
    trainer_type& trainer,
    double C
)
{
    pyassert(C > 0, "C must be > 0");

    trainer.set_c_class2(C);
}
69

Davis King's avatar
Davis King committed
70
71
72
73
74
75
template <typename trainer_type>
double get_c_class1 (
    const trainer_type& trainer
)
{
    // Getter bound to the Python "c_class1" property.
    const double c = trainer.get_c_class1();
    return c;
}
template <typename trainer_type>
double get_c_class2 (
    const trainer_type& trainer
)
{
    // Getter bound to the Python "c_class2" property.
    const double c = trainer.get_c_class2();
    return c;
}

template <typename trainer_type>
76
py::class_<trainer_type> setup_trainer_eps (
77
    py::module& m,
Davis King's avatar
Davis King committed
78
79
80
    const std::string& name
)
{
81
    return py::class_<trainer_type>(m, name.c_str())
Davis King's avatar
Davis King committed
82
        .def("train", train<trainer_type>)
83
        .def_property("epsilon", get_epsilon<trainer_type>, set_epsilon<trainer_type>);
Davis King's avatar
Davis King committed
84
85
86
}

template <typename trainer_type>
87
py::class_<trainer_type> setup_trainer_eps_c (
88
    py::module& m,
Davis King's avatar
Davis King committed
89
90
91
    const std::string& name
)
{
92
93
94
95
    return setup_trainer_eps<trainer_type>(m, name)
        .def("set_c", set_c<trainer_type>)
        .def_property("c_class1", get_c_class1<trainer_type>, set_c_class1<trainer_type>)
        .def_property("c_class2", get_c_class2<trainer_type>, set_c_class2<trainer_type>);
Davis King's avatar
Davis King committed
96
97
}

98
99
100
101
// Extends setup_trainer_eps_c() with a cache_size property
// (get_cache_size / set_cache_size). Returns the class object so callers can
// chain further .def() calls.
template <typename trainer_type>
py::class_<trainer_type> setup_trainer_eps_c_cache (
    py::module& m,
    const std::string& name
)
{
    return setup_trainer_eps_c<trainer_type>(m, name)
        .def_property("cache_size", get_cache_size<trainer_type>, set_cache_size<trainer_type>);
}

108
109
110
template <typename trainer_type>
void set_gamma (
    trainer_type& trainer,
Davis King's avatar
Davis King committed
111
    double gamma
112
113
)
{
Davis King's avatar
Davis King committed
114
    pyassert(gamma > 0, "gamma must be > 0");
115
    trainer.set_kernel(typename trainer_type::kernel_type(gamma));
Davis King's avatar
Davis King committed
116
}
117

118
119
120
// Getter bound to the Python "gamma" property: reads the gamma member of the
// kernel currently held by the trainer.
template <typename trainer_type>
double get_gamma (
    const trainer_type& trainer
)
{
    return trainer.get_kernel().gamma;
}

126
127
128
129
130
131
132
133
134
// ----------------------------------------------------------------------------------------

// Python binding for cross_validate_trainer().
// Before any training, inputs are validated through pyassert:
//   - (x, y) must satisfy is_binary_classification_problem()
//   - folds must lie in [2, x.size()]
// It then runs cross_validate_trainer(trainer, x, y, folds) and returns the
// result as a binary_test.
template <
    typename trainer_type
    >
const binary_test _cross_validate_trainer (
    const trainer_type& trainer,
    const std::vector<typename trainer_type::sample_type>& x,
    const std::vector<double>& y,
    const unsigned long folds
)
{
    pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set.");
    pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given.");
    return cross_validate_trainer(trainer, x, y, folds);
}

// Python binding for cross_validate_trainer_threaded().
// Performs the same input checks as the single-threaded wrapper, plus it
// requires num_threads to be non-zero. It then forwards to
// cross_validate_trainer_threaded(trainer, x, y, folds, num_threads) and
// returns the result as a binary_test.
template <
    typename trainer_type
    >
const binary_test _cross_validate_trainer_t (
    const trainer_type& trainer,
    const std::vector<typename trainer_type::sample_type>& x,
    const std::vector<double>& y,
    const unsigned long folds,
    const unsigned long num_threads
)
{
    pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set.");
    pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given.");
    // Only zero threads is invalid; the previous `1 < num_threads` test also
    // rejected num_threads == 1, contradicting the error message below.
    pyassert(num_threads > 0, "The number of threads specified must not be zero.");
    return cross_validate_trainer_threaded(trainer, x, y, folds, num_threads);
}
159

Davis King's avatar
Davis King committed
160
161
// ----------------------------------------------------------------------------------------

162
void bind_svm_c_trainer(py::module& m)
163
{
164
    namespace py = pybind11;
165
166

    // svm_c
Davis King's avatar
Davis King committed
167
168
    {
        typedef svm_c_trainer<radial_basis_kernel<sample_type> > T;
169
170
171
        setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_radial_basis")
            .def(py::init())
            .def_property("gamma", get_gamma<T>, set_gamma<T>);
172
173
174
175
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
Davis King's avatar
Davis King committed
176
177
178
179
    }

    {
        typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
180
181
182
        setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_radial_basis")
            .def(py::init())
            .def_property("gamma", get_gamma<T>, set_gamma<T>);
183
184
185
186
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
Davis King's avatar
Davis King committed
187
188
189
190
    }

    {
        typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T;
191
192
        setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_histogram_intersection")
            .def(py::init());
193
194
195
196
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
Davis King's avatar
Davis King committed
197
198
199
200
    }

    {
        typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
201
202
        setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_histogram_intersection")
            .def(py::init());
203
204
205
206
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
Davis King's avatar
Davis King committed
207
208
    }

209
    // svm_c_linear
Davis King's avatar
Davis King committed
210
211
    {
        typedef svm_c_linear_trainer<linear_kernel<sample_type> > T;
212
        setup_trainer_eps_c<T>(m, "svm_c_trainer_linear")
213
214
215
216
217
            .def(py::init())
            .def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations)
            .def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1)
            .def_property("learns_nonnegative_weights", &T::learns_nonnegative_weights, &T::set_learns_nonnegative_weights)
            .def_property_readonly("has_prior", &T::has_prior)
218
            .def("set_prior", &T::set_prior)
Davis King's avatar
Davis King committed
219
220
221
            .def("be_verbose", &T::be_verbose)
            .def("be_quiet", &T::be_quiet);

222
223
224
225
        m.def("cross_validate_trainer", _cross_validate_trainer<T>,
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
Davis King's avatar
Davis King committed
226
227
228
229
    }

    {
        typedef svm_c_linear_trainer<sparse_linear_kernel<sparse_vect> > T;
230
231
        setup_trainer_eps_c<T>(m, "svm_c_trainer_sparse_linear")
            .def(py::init())
232
233
234
235
            .def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations)
            .def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1)
            .def_property("learns_nonnegative_weights", &T::learns_nonnegative_weights, &T::set_learns_nonnegative_weights)
            .def_property_readonly("has_prior", &T::has_prior)
236
            .def("set_prior", &T::set_prior)
Davis King's avatar
Davis King committed
237
238
239
            .def("be_verbose", &T::be_verbose)
            .def("be_quiet", &T::be_quiet);

240
241
242
243
        m.def("cross_validate_trainer", _cross_validate_trainer<T>,
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
Davis King's avatar
Davis King committed
244
    }
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308

    // rvm
    {
        typedef rvm_trainer<radial_basis_kernel<sample_type> > T;
        setup_trainer_eps<T>(m, "rvm_trainer_radial_basis")
            .def(py::init())
            .def_property("gamma", get_gamma<T>, set_gamma<T>);
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
    }

    {
        typedef rvm_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
        setup_trainer_eps<T>(m, "rvm_trainer_sparse_radial_basis")
            .def(py::init())
            .def_property("gamma", get_gamma<T>, set_gamma<T>);
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
    }

    {
        typedef rvm_trainer<histogram_intersection_kernel<sample_type> > T;
        setup_trainer_eps<T>(m, "rvm_trainer_histogram_intersection")
            .def(py::init());
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
    }

    {
        typedef rvm_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
        setup_trainer_eps<T>(m, "rvm_trainer_sparse_histogram_intersection")
            .def(py::init());
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
    }

    // rvm linear
    {
        typedef rvm_trainer<linear_kernel<sample_type> > T;
        setup_trainer_eps<T>(m, "rvm_trainer_linear")
            .def(py::init());
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
    }

    {
        typedef rvm_trainer<sparse_linear_kernel<sparse_vect> > T;
        setup_trainer_eps<T>(m, "rvm_trainer_sparse_linear")
            .def(py::init());
        m.def("cross_validate_trainer", _cross_validate_trainer<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
        m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, 
            py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
    }
309
310
311
}