// structural_object_detection_trainer.h
// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__
#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__

#include "structural_object_detection_trainer_abstract.h"
#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_object_detection_problem.h"
#include "../image_processing/object_detector.h"
#include "../image_processing/box_overlap_testing.h"
#include "../image_processing/full_object_detection.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
21
        typename image_scanner_type
22
23
24
25
26
27
28
        >
    class structural_object_detection_trainer : noncopyable
    {

    public:
        typedef double scalar_type;
        typedef default_memory_manager mem_manager_type;
29
        typedef object_detector<image_scanner_type> trained_function_type;
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


        explicit structural_object_detection_trainer (
            const image_scanner_type& scanner_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(scanner_.get_num_detection_templates() > 0,
                "\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)"
                << "\n\t You can't have zero detection templates"
                << "\n\t this: " << this
                );

            C = 1;
            verbose = false;
            eps = 0.3;
            num_threads = 2;
47
            max_cache_size = 5;
48
            match_eps = 0.5;
49
50
51
52
            loss_per_missed_target = 1;
            loss_per_false_alarm = 1;

            scanner.copy_configuration(scanner_);
53

54
            auto_overlap_tester = true;
55
56
        }

57
58
59
60
61
62
        const image_scanner_type& get_scanner (
        ) const
        {
            return scanner;
        }

63
64
65
66
        bool auto_set_overlap_tester (
        ) const 
        { 
            return auto_overlap_tester; 
67
68
69
        }

        void set_overlap_tester (
70
            const test_box_overlap& tester
71
72
73
        )
        {
            overlap_tester = tester;
74
            auto_overlap_tester = false;
75
76
        }

77
        test_box_overlap get_overlap_tester (
78
79
        ) const
        {
80
81
            // make sure requires clause is not broken
            DLIB_ASSERT(auto_set_overlap_tester() == false,
82
                "\t test_box_overlap structural_object_detection_trainer::get_overlap_tester()"
83
84
85
86
                << "\n\t You can't call this function if the overlap tester is generated dynamically."
                << "\n\t this: " << this
                );

87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
            return overlap_tester;
        }

        void set_num_threads (
            unsigned long num
        )
        {
            num_threads = num;
        }

        unsigned long get_num_threads (
        ) const
        {
            return num_threads;
        }

        void set_epsilon (
            scalar_type eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void structural_object_detection_trainer::set_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_ 
                << "\n\t this: " << this
                );

            eps = eps_;
        }

118
        scalar_type get_epsilon (
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
        ) const { return eps; }

        void set_max_cache_size (
            unsigned long max_size
        )
        {
            max_cache_size = max_size;
        }

        unsigned long get_max_cache_size (
        ) const
        {
            return max_cache_size; 
        }

        void be_verbose (
        )
        {
            verbose = true;
        }

        void be_quiet (
        )
        {
            verbose = false;
        }

        void set_oca (
            const oca& item
        )
        {
            solver = item;
        }

        const oca get_oca (
        ) const
        {
            return solver;
        }

        void set_c (
            scalar_type C_ 
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C_ > 0,
                "\t void structural_object_detection_trainer::set_c()"
                << "\n\t C_ must be greater than 0"
                << "\n\t C_:    " << C_ 
                << "\n\t this: " << this
                );

            C = C_;
        }

174
        scalar_type get_c (
175
176
177
178
179
        ) const
        {
            return C;
        }

180
        void set_match_eps (
181
182
183
184
185
            double eps
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(0 < eps && eps < 1, 
186
                "\t void structural_object_detection_trainer::set_match_eps(eps)"
187
188
189
190
191
                << "\n\t Invalid inputs were given to this function "
                << "\n\t eps:  " << eps 
                << "\n\t this: " << this
                );

192
            match_eps = eps;
193
194
        }

195
        double get_match_eps (
196
197
        ) const
        {
198
            return match_eps;
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
        }

        double get_loss_per_missed_target (
        ) const
        {
            return loss_per_missed_target;
        }

        void set_loss_per_missed_target (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_missed_target = loss;
        }

        double get_loss_per_false_alarm (
        ) const
        {
            return loss_per_false_alarm;
        }

        void set_loss_per_false_alarm (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_false_alarm = loss;
        }

        template <
            typename image_array_type
            >
        const trained_function_type train (
            const image_array_type& images,
248
            const std::vector<std::vector<full_object_detection> >& truth_object_detections
249
250
        ) const
        {
251
#ifdef ENABLE_ASSERTS
252
            // make sure requires clause is not broken
253
254
            DLIB_ASSERT(is_learning_problem(images,truth_object_detections) == true,
                "\t trained_function_type structural_object_detection_trainer::train()"
255
256
                << "\n\t invalid inputs were given to this function"
                << "\n\t images.size():      " << images.size()
257
258
                << "\n\t truth_object_detections.size(): " << truth_object_detections.size()
                << "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections)
259
                );
260
261
262
263
            for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
            {
                for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
                {
264
                    DLIB_ASSERT(truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() &&
Davis King's avatar
Davis King committed
265
                                all_parts_in_rect(truth_object_detections[i][j]) == true,
266
267
                        "\t trained_function_type structural_object_detection_trainer::train()"
                        << "\n\t invalid inputs were given to this function"
268
269
                        << "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts():                " << 
                            truth_object_detections[i][j].num_parts()
270
271
                        << "\n\t get_scanner().get_num_movable_components_per_detection_template(): " << 
                            get_scanner().get_num_movable_components_per_detection_template()
Davis King's avatar
Davis King committed
272
                        << "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
273
274
275
276
                    );
                }
            }
#endif
277

278
            structural_svm_object_detection_problem<image_scanner_type,image_array_type > 
279
                svm_prob(scanner, overlap_tester, auto_overlap_tester, images, truth_object_detections, num_threads);
280
281
282
283
284
285
286

            if (verbose)
                svm_prob.be_verbose();

            svm_prob.set_c(C);
            svm_prob.set_epsilon(eps);
            svm_prob.set_max_cache_size(max_cache_size);
287
            svm_prob.set_match_eps(match_eps);
288
289
290
            svm_prob.set_loss_per_missed_target(loss_per_missed_target);
            svm_prob.set_loss_per_false_alarm(loss_per_false_alarm);
            matrix<double,0,1> w;
Davis King's avatar
Davis King committed
291
292

            // Run the optimizer to find the optimal w.
293
294
            solver(svm_prob,w);

Davis King's avatar
Davis King committed
295
            // report the results of the training.
296
            return object_detector<image_scanner_type>(scanner, svm_prob.get_overlap_tester(), w);
297
298
        }

299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
        template <
            typename image_array_type
            >
        const trained_function_type train (
            const image_array_type& images,
            const std::vector<std::vector<rectangle> >& truth_object_detections
        ) const
        {
            std::vector<std::vector<full_object_detection> > truth_dets(truth_object_detections.size());
            for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
            {
                for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
                {
                    truth_dets[i].push_back(full_object_detection(truth_object_detections[i][j]));
                }
            }

            return train(images, truth_dets);
        }
318
319
320
321

    private:

        image_scanner_type scanner;
322
        test_box_overlap overlap_tester;
323
324
325
326

        double C;
        oca solver;
        double eps;
327
        double match_eps;
328
329
330
331
332
        bool verbose;
        unsigned long num_threads;
        unsigned long max_cache_size;
        double loss_per_missed_target;
        double loss_per_false_alarm;
333
        bool auto_overlap_tester;
334
335
336
337
338
339
340
341
342
343

    }; 

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__