// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SVm_THREADED_
#define DLIB_SVm_THREADED_

#include <cmath>
#include <iostream>
#include <limits>
#include <sstream>
#include <vector>

#include "svm_threaded_abstract.h"
#include "svm.h"
#include "../matrix.h"
#include "../algs.h"
#include "../serialize.h"
#include "function.h"
#include "kernel.h"
#include "../threads.h"
#include "../pipe.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace cvtti_helpers
    {
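        // A job bundles everything one cross validation fold needs: a private copy of
        // the trainer, index vectors naming the fold's train/test rows, the matching
        // labels, and a pointer back to the full set of samples.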
        template <typename trainer_type, typename in_sample_vector_type>
        struct job
        {
            typedef typename trainer_type::scalar_type scalar_type;
            typedef typename trainer_type::sample_type sample_type;
            typedef typename trainer_type::mem_manager_type mem_manager_type;
            typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
            typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;

            job() : x(0) {}

            trainer_type trainer;
            matrix<long,0,1> x_test, x_train;
            scalar_vector_type y_test, y_train;
            const in_sample_vector_type* x;
        };

        struct task
        {
            template <
                typename trainer_type,
                typename matrix_type,
                typename in_sample_vector_type
                >
            void operator()(
                job<trainer_type,in_sample_vector_type>& j,
                matrix_type& result
            )
            {
                try
                {
                    result = test_binary_decision_function(j.trainer.train(rowm(*j.x,j.x_train), j.y_train), rowm(*j.x,j.x_test), j.y_test);

                    // Do this just to make j release its memory, since people might run threaded cross
                    // validation on very large datasets.  Every bit of freed memory helps out.
                    j = job<trainer_type,in_sample_vector_type>();
                }
                catch (invalid_nu_error&)
                {
                    // If this is an svm_nu_trainer then we might get this exception if the nu is
                    // invalid.  In this case just return a cross validation score of 0.
                    result = 0;
                }
                catch (std::bad_alloc&)
                {
                    std::cerr << "\nstd::bad_alloc thrown while running cross_validate_trainer_threaded().  Not enough memory.\n" << std::endl;
                    throw;
                }
            }
        };
    }
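
    // Does the real work, operating on matrix column vectors of samples and labels.
    // Users should call cross_validate_trainer_threaded() below, which converts its
    // arguments with mat() before delegating here.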

    template <
        typename trainer_type,
        typename in_sample_vector_type,
        typename in_scalar_vector_type
        >
    const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> 
    cross_validate_trainer_threaded_impl (
        const trainer_type& trainer,
        const in_sample_vector_type& x,
        const in_scalar_vector_type& y,
        const long folds,
        const long num_threads
    )
    {
        using namespace dlib::cvtti_helpers;
        typedef typename trainer_type::scalar_type scalar_type;
        typedef typename trainer_type::mem_manager_type mem_manager_type;

        // make sure requires clause is not broken
        DLIB_ASSERT(is_binary_classification_problem(x,y) == true &&
                    1 < folds && folds <= std::min(sum(y>0),sum(y<0)) &&
                    num_threads > 0,
            "\tmatrix cross_validate_trainer_threaded()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t std::min(sum(y>0),sum(y<0)): " << std::min(sum(y>0),sum(y<0))
            << "\n\t folds:  " << folds
            << "\n\t num_threads:  " << num_threads
            << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
            );


        task mytask;
        thread_pool tp(num_threads);


        // count the number of positive and negative examples
        long num_pos = 0;
        long num_neg = 0;
        for (long r = 0; r < y.nr(); ++r)
        {
            if (y(r) == +1.0)
                ++num_pos;
            else
                ++num_neg;
        }
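        // (is_binary_classification_problem(x,y), checked above, guarantees every label
        // is exactly +1 or -1, so the else branch only ever counts -1 labels.)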

        // figure out how many positive and negative examples we will have in each fold
        const long num_pos_test_samples = num_pos/folds; 
        const long num_pos_train_samples = num_pos - num_pos_test_samples; 
        const long num_neg_test_samples = num_neg/folds; 
        const long num_neg_train_samples = num_neg - num_neg_test_samples; 
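        // (Worked example: with 90 positives, 110 negatives, and folds == 10, each fold
        // tests on 9 positives and 11 negatives and trains on the remaining 81 and 99.)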


        long pos_idx = 0;
        long neg_idx = 0;



        std::vector<future<job<trainer_type,in_sample_vector_type> > > jobs(folds);
        std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds);
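        // The thread_pool fills each results[i] in as its task runs; future::get()
        // blocks until the task using that future has finished with it.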


        for (long i = 0; i < folds; ++i)
        {
            job<trainer_type,in_sample_vector_type>& j = jobs[i].get();

            j.x = &x;
            j.x_test.set_size (num_pos_test_samples  + num_neg_test_samples);
            j.y_test.set_size (num_pos_test_samples  + num_neg_test_samples);
            j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
            j.y_train.set_size(num_pos_train_samples + num_neg_train_samples);
            j.trainer = trainer;

            long cur = 0;
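
            // Walk the dataset circularly so that successive folds pick up where the
            // previous fold's test split ended.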

            // load up our positive test samples
            while (cur < num_pos_test_samples)
            {
                if (y(pos_idx) == +1.0)
                {
                    j.x_test(cur) = pos_idx;
                    j.y_test(cur) = +1.0;
                    ++cur;
                }
                pos_idx = (pos_idx+1)%x.nr();
            }

            // load up our negative test samples
            while (cur < j.x_test.nr())
            {
                if (y(neg_idx) == -1.0)
                {
                    j.x_test(cur) = neg_idx;
                    j.y_test(cur) = -1.0;
                    ++cur;
                }
                neg_idx = (neg_idx+1)%x.nr();
            }

            // load the training data from the data following whatever we loaded
            // as the testing data
            long train_pos_idx = pos_idx;
            long train_neg_idx = neg_idx;
            cur = 0;

            // load up our positive train samples
            while (cur < num_pos_train_samples)
            {
                if (y(train_pos_idx) == +1.0)
                {
                    j.x_train(cur) = train_pos_idx;
                    j.y_train(cur) = +1.0;
                    ++cur;
                }
                train_pos_idx = (train_pos_idx+1)%x.nr();
            }

            // load up our negative train samples
            while (cur < j.x_train.nr())
            {
                if (y(train_neg_idx) == -1.0)
                {
                    j.x_train(cur) = train_neg_idx;
                    j.y_train(cur) = -1.0;
                    ++cur;
                }
                train_neg_idx = (train_neg_idx+1)%x.nr();
            }

            // finally spawn a task to process this job
            tp.add_task(mytask, jobs[i], results[i]);

        } // for (long i = 0; i < folds; ++i)

        matrix<scalar_type, 1, 2, mem_manager_type> res;
        set_all_elements(res,0);

        // now compute the total results
        for (long i = 0; i < folds; ++i)
        {
            res += results[i].get();
        }

        return res/(scalar_type)folds;
    }

    template <
        typename trainer_type,
        typename in_sample_vector_type,
        typename in_scalar_vector_type
        >
    const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> 
    cross_validate_trainer_threaded (
        const trainer_type& trainer,
        const in_sample_vector_type& x,
        const in_scalar_vector_type& y,
        const long folds,
        const long num_threads
    )
    {
        return cross_validate_trainer_threaded_impl(trainer,
                                           mat(x),
                                           mat(y),
                                           folds,
                                           num_threads);
    }
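
    // A minimal usage sketch (illustrative only, not part of this header).  It assumes
    // dlib's svm_nu_trainer and radial_basis_kernel along with a +1/-1 labeled dataset:
    //
    //     typedef matrix<double,2,1> sample_type;
    //     typedef radial_basis_kernel<sample_type> kernel_type;
    //
    //     std::vector<sample_type> samples;
    //     std::vector<double> labels;
    //     // ... fill samples and labels with a binary classification problem ...
    //
    //     svm_nu_trainer<kernel_type> trainer;
    //     trainer.set_kernel(kernel_type(0.1));
    //     trainer.set_nu(0.05);
    //
    //     // 4-fold cross validation spread over 4 threads.  The returned 1x2 matrix
    //     // holds the fraction of +1 and -1 test examples classified correctly.
    //     std::cout << cross_validate_trainer_threaded(trainer, samples, labels, 4, 4);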

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SVm_THREADED_