// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_OBJECT_DeTECTOR_Hh_
#define DLIB_OBJECT_DeTECTOR_Hh_

#include "object_detector_abstract.h"
#include "../geometry.h"
#include <algorithm>
#include <vector>
#include "box_overlap_testing.h"
#include "full_object_detection.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    struct rect_detection
    {
        // Confidence score of this detection.  It is the only field used for ordering.
        double detection_confidence;
        // Index of the weight vector (sub-detector) that produced this detection.
        unsigned long weight_index;
        // Where the detection is located.
        rectangle rect;

        // Orders detections by ascending detection_confidence.
        bool operator< (
            const rect_detection& item
        ) const 
        { 
            return item.detection_confidence > detection_confidence; 
        }
    };

    struct full_detection
    {
        // Confidence score of this detection.  It is the only field used for ordering.
        double detection_confidence;
        // Index of the weight vector (sub-detector) that produced this detection.
        unsigned long weight_index;
        // The detection, expressed as a full_object_detection.
        full_object_detection rect;

        // Orders detections by ascending detection_confidence.
        bool operator< (
            const full_detection& item
        ) const 
        { 
            return item.detection_confidence > detection_confidence; 
        }
    };

// ----------------------------------------------------------------------------------------

    template <typename image_scanner_type>
    struct processed_weight_vector
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                Holds a raw weight vector w together with whatever should be handed to
                image_scanner_type::detect() as its first argument.  This generic
                version simply passes w through unchanged.  Image scanners may overload
                this template to precompute a different detect() argument.  For example,
                scan_fhog_pyramid uses an overload whose get_detect_argument() returns
                an fhog_filterbank.  That way the filterbank is not rebuilt on every call
                to detect, which speeds up detection.
        !*/

        typedef typename image_scanner_type::feature_vector_type feature_vector_type;

        processed_weight_vector () {}

        void init (
            const image_scanner_type& 
        ) 
        /*!
            requires
                - w has already been assigned its value.
            ensures
                - Does nothing in this generic version.  Overloads of this template use
                  this hook to build their detect() argument from w.
        !*/
        {}

        // The value to pass as the first argument of image_scanner_type::detect().
        const feature_vector_type& get_detect_argument (
        ) const { return w; }

        feature_vector_type w;
    };

// ----------------------------------------------------------------------------------------

    template <
70
        typename image_scanner_type_
Davis King's avatar
Davis King committed
71
72
73
74
        >
    class object_detector
    {
    public:
75
        typedef image_scanner_type_ image_scanner_type;
Davis King's avatar
Davis King committed
76
77
        typedef typename image_scanner_type::feature_vector_type feature_vector_type;

Davis King's avatar
Davis King committed
78
79
80
81
82
83
84
85
86
        object_detector (
        );

        object_detector (
            const object_detector& item 
        );

        object_detector (
            const image_scanner_type& scanner_, 
87
            const test_box_overlap& overlap_tester_,
Davis King's avatar
Davis King committed
88
            const feature_vector_type& w_ 
Davis King's avatar
Davis King committed
89
90
        );

91
92
93
94
95
96
        object_detector (
            const image_scanner_type& scanner_, 
            const test_box_overlap& overlap_tester_,
            const std::vector<feature_vector_type>& w_ 
        );

97
98
99
100
        explicit object_detector (
            const std::vector<object_detector>& detectors
        );

101
102
103
        unsigned long num_detectors (
        ) const { return w.size(); }

Davis King's avatar
Davis King committed
104
        const feature_vector_type& get_w (
105
106
            unsigned long idx = 0
        ) const { return w[idx].w; }
107
108
109
110
        
        const processed_weight_vector<image_scanner_type>& get_processed_w (
            unsigned long idx = 0
        ) const { return w[idx]; }
111

112
        const test_box_overlap& get_overlap_tester (
113
114
115
116
117
        ) const;

        const image_scanner_type& get_scanner (
        ) const;

Davis King's avatar
Davis King committed
118
119
120
121
122
123
124
125
        object_detector& operator= (
            const object_detector& item 
        );

        template <
            typename image_type
            >
        std::vector<rectangle> operator() (
126
127
            const image_type& img,
            double adjust_threshold = 0
128
        );
Davis King's avatar
Davis King committed
129

130
131
132
133
134
        template <
            typename image_type
            >
        void operator() (
            const image_type& img,
135
            std::vector<std::pair<double, rectangle> >& final_dets,
136
            double adjust_threshold = 0
137
        );
138

139
140
141
142
143
144
145
146
147
        template <
            typename image_type
            >
        void operator() (
            const image_type& img,
            std::vector<std::pair<double, full_object_detection> >& final_dets,
            double adjust_threshold = 0
        );

148
149
150
151
152
153
154
155
156
        template <
            typename image_type
            >
        void operator() (
            const image_type& img,
            std::vector<full_object_detection>& final_dets,
            double adjust_threshold = 0
        );

157
158
159
160
        // These typedefs are here for backwards compatibility with previous versions of
        // dlib.
        typedef ::dlib::rect_detection rect_detection;
        typedef ::dlib::full_detection full_detection;
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179

        template <
            typename image_type
            >
        void operator() (
            const image_type& img,
            std::vector<rect_detection>& final_dets,
            double adjust_threshold = 0
        );

        template <
            typename image_type
            >
        void operator() (
            const image_type& img,
            std::vector<full_detection>& final_dets,
            double adjust_threshold = 0
        );

180
        template <typename T>
Davis King's avatar
Davis King committed
181
        friend void serialize (
182
            const object_detector<T>& item,
Davis King's avatar
Davis King committed
183
184
185
            std::ostream& out
        );

186
        template <typename T>
Davis King's avatar
Davis King committed
187
        friend void deserialize (
188
            object_detector<T>& item,
Davis King's avatar
Davis King committed
189
190
191
192
193
194
            std::istream& in 
        );

    private:

        bool overlaps_any_box (
195
            const std::vector<rect_detection>& rects,
196
197
198
199
200
            const dlib::rectangle& rect
        ) const
        {
            for (unsigned long i = 0; i < rects.size(); ++i)
            {
201
                if (boxes_overlap(rects[i].rect, rect))
202
203
204
205
206
                    return true;
            }
            return false;
        }

207
        test_box_overlap boxes_overlap;
208
        std::vector<processed_weight_vector<image_scanner_type> > w;
209
        image_scanner_type scanner;
Davis King's avatar
Davis King committed
210
211
212
213
    };

// ----------------------------------------------------------------------------------------

    template <typename T>
Davis King's avatar
Davis King committed
215
    void serialize (
216
        const object_detector<T>& item,
Davis King's avatar
Davis King committed
217
218
219
        std::ostream& out
    )
    {
220
        int version = 2;
221
222
        serialize(version, out);

Davis King's avatar
Davis King committed
223
224
225
226
        T scanner;
        scanner.copy_configuration(item.scanner);
        serialize(scanner, out);
        serialize(item.boxes_overlap, out);
227
228
229
230
        // serialize all the weight vectors
        serialize(item.w.size(), out);
        for (unsigned long i = 0; i < item.w.size(); ++i)
            serialize(item.w[i].w, out);
Davis King's avatar
Davis King committed
231
232
233
234
    }

// ----------------------------------------------------------------------------------------

    template <typename T>
Davis King's avatar
Davis King committed
236
    void deserialize (
237
        object_detector<T>& item,
Davis King's avatar
Davis King committed
238
239
240
        std::istream& in 
    )
    {
241
242
        int version = 0;
        deserialize(version, in);
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
        if (version == 1)
        {
            deserialize(item.scanner, in);
            item.w.resize(1);
            deserialize(item.w[0].w, in);
            item.w[0].init(item.scanner);
            deserialize(item.boxes_overlap, in);
        }
        else if (version == 2)
        {
            deserialize(item.scanner, in);
            deserialize(item.boxes_overlap, in);
            unsigned long num_detectors = 0;
            deserialize(num_detectors, in);
            item.w.resize(num_detectors);
            for (unsigned long i = 0; i < item.w.size(); ++i)
            {
                deserialize(item.w[i].w, in);
                item.w[i].init(item.scanner);
            }
        }
        else 
        {
266
            throw serialization_error("Unexpected version encountered while deserializing a dlib::object_detector object.");
267
        }
Davis King's avatar
Davis King committed
268
269
270
271
272
273
274
275
276
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                      object_detector member functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <
277
        typename image_scanner_type
Davis King's avatar
Davis King committed
278
        >
279
    object_detector<image_scanner_type>::
Davis King's avatar
Davis King committed
280
281
282
283
284
285
286
287
    object_detector (
    )
    {
    }

// ----------------------------------------------------------------------------------------

    template <
288
        typename image_scanner_type
Davis King's avatar
Davis King committed
289
        >
290
    object_detector<image_scanner_type>::
Davis King's avatar
Davis King committed
291
292
293
294
295
296
297
298
299
300
301
302
    object_detector (
        const object_detector& item 
    )
    {
        boxes_overlap = item.boxes_overlap;
        w = item.w;
        scanner.copy_configuration(item.scanner);
    }

// ----------------------------------------------------------------------------------------

    template <
303
        typename image_scanner_type
Davis King's avatar
Davis King committed
304
        >
305
    object_detector<image_scanner_type>::
Davis King's avatar
Davis King committed
306
307
    object_detector (
        const image_scanner_type& scanner_, 
308
        const test_box_overlap& overlap_tester,
Davis King's avatar
Davis King committed
309
        const feature_vector_type& w_ 
Davis King's avatar
Davis King committed
310
    ) :
311
        boxes_overlap(overlap_tester)
Davis King's avatar
Davis King committed
312
    {
313
314
        // make sure requires clause is not broken
        DLIB_ASSERT(scanner_.get_num_detection_templates() > 0 &&
Davis King's avatar
Davis King committed
315
                    w_.size() == scanner_.get_num_dimensions() + 1, 
316
317
318
319
320
321
322
323
            "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
            << "\n\t w_.size():                     " << w_.size()
            << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions()
            << "\n\t this: " << this
            );

Davis King's avatar
Davis King committed
324
        scanner.copy_configuration(scanner_);
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
        w.resize(1);
        w[0].w = w_;
        w[0].init(scanner);
    }

// ----------------------------------------------------------------------------------------

    template <
        typename image_scanner_type
        >
    object_detector<image_scanner_type>::
    object_detector (
        const image_scanner_type& scanner_, 
        const test_box_overlap& overlap_tester,
        const std::vector<feature_vector_type>& w_ 
    ) :
        boxes_overlap(overlap_tester)
    /*!
        Builds a detector with one sub-detector per element of w_.  All inputs are
        validated before any member besides boxes_overlap is touched.
    !*/
    {
        // make sure requires clause is not broken
        DLIB_CASSERT(scanner_.get_num_detection_templates() > 0 && w_.size() > 0,
            "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
            << "\n\t w_.size():                     " << w_.size()
            << "\n\t this: " << this
            );

        // Every weight vector needs one entry per feature dimension plus a threshold.
        for (unsigned long idx = 0; idx < w_.size(); ++idx)
        {
            DLIB_CASSERT(w_[idx].size() == scanner_.get_num_dimensions() + 1, 
                "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
                << "\n\t w_["<<idx<<"].size():                     " << w_[idx].size()
                << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions()
                << "\n\t this: " << this
                );
        }

        scanner.copy_configuration(scanner_);

        // Store each weight vector, then let it build its detect() argument.
        w.resize(w_.size());
        for (unsigned long idx = 0; idx < w_.size(); ++idx)
        {
            processed_weight_vector<image_scanner_type>& slot = w[idx];
            slot.w = w_[idx];
            slot.init(scanner);
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename image_scanner_type
        >
    object_detector<image_scanner_type>::
    object_detector (
        const std::vector<object_detector>& detectors
    )
    /*!
        Merges several detectors into one.  The weight vectors of every element of
        detectors are concatenated in order.  The scanner configuration and overlap
        tester are taken from detectors[0].
    !*/
    {
        DLIB_CASSERT(detectors.size() != 0,
                "\t object_detector::object_detector(detectors)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t this: " << this
        );

        // Each element of detectors may hold several weight vectors.  Count them first
        // so the reserve() below matches the final size.  The previous
        // reserve(detectors.size()) undercounted in that case.
        unsigned long total_weights = 0;
        for (unsigned long i = 0; i < detectors.size(); ++i)
            total_weights += detectors[i].num_detectors();

        std::vector<feature_vector_type> weights;
        weights.reserve(total_weights);
        for (unsigned long i = 0; i < detectors.size(); ++i)
        {
            for (unsigned long j = 0; j < detectors[i].num_detectors(); ++j)
                weights.push_back(detectors[i].get_w(j));
        }

        *this = object_detector(detectors[0].get_scanner(), detectors[0].get_overlap_tester(), weights);
    }

// ----------------------------------------------------------------------------------------

    template <
402
        typename image_scanner_type
Davis King's avatar
Davis King committed
403
        >
404
    object_detector<image_scanner_type>& object_detector<image_scanner_type>::
Davis King's avatar
Davis King committed
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
    operator= (
        const object_detector& item 
    )
    {
        if (this == &item)
            return *this;

        boxes_overlap = item.boxes_overlap;
        w = item.w;
        scanner.copy_configuration(item.scanner);
        return *this;
    }

// ----------------------------------------------------------------------------------------

    template <
421
        typename image_scanner_type
Davis King's avatar
Davis King committed
422
423
424
425
        >
    template <
        typename image_type
        >
426
    void object_detector<image_scanner_type>::
Davis King's avatar
Davis King committed
427
    operator() (
428
        const image_type& img,
429
        std::vector<rect_detection>& final_dets,
430
        double adjust_threshold
431
    ) 
Davis King's avatar
Davis King committed
432
    {
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
        scanner.load(img);
        std::vector<std::pair<double, rectangle> > dets;
        std::vector<rect_detection> dets_accum;
        for (unsigned long i = 0; i < w.size(); ++i)
        {
            const double thresh = w[i].w(scanner.get_num_dimensions());
            scanner.detect(w[i].get_detect_argument(), dets, thresh + adjust_threshold);
            for (unsigned long j = 0; j < dets.size(); ++j)
            {
                rect_detection temp;
                temp.detection_confidence = dets[j].first-thresh;
                temp.weight_index = i;
                temp.rect = dets[j].second;
                dets_accum.push_back(temp);
            }
        }

        // Do non-max suppression
        final_dets.clear();
452
        if (w.size() > 1)
453
            std::sort(dets_accum.rbegin(), dets_accum.rend());
454
        for (unsigned long i = 0; i < dets_accum.size(); ++i)
Davis King's avatar
Davis King committed
455
        {
456
457
            if (overlaps_any_box(final_dets, dets_accum[i].rect))
                continue;
Davis King's avatar
Davis King committed
458

459
460
461
            final_dets.push_back(dets_accum[i]);
        }
    }
// ----------------------------------------------------------------------------------------
    template <
        typename image_scanner_type
        >
    template <
        typename image_type
        >
    void object_detector<image_scanner_type>::
    operator() (
        const image_type& img,
        std::vector<full_detection>& final_dets,
        double adjust_threshold 
    )
    /*!
        Same as the rect_detection overload, except that each rectangle is upgraded
        to a full_object_detection using the weight vector that produced it.
    !*/
    {
        std::vector<rect_detection> rect_dets;
        (*this)(img, rect_dets, adjust_threshold);

        final_dets.resize(rect_dets.size());
        for (unsigned long k = 0; k < rect_dets.size(); ++k)
        {
            const rect_detection& src = rect_dets[k];
            full_detection& dst = final_dets[k];
            dst.detection_confidence = src.detection_confidence;
            dst.weight_index = src.weight_index;
            dst.rect = scanner.get_full_object_detection(src.rect, w[src.weight_index].w);
        }
    }

// ----------------------------------------------------------------------------------------

    template <
        typename image_scanner_type
        >
    template <
        typename image_type
        >
    std::vector<rectangle> object_detector<image_scanner_type>::
    operator() (
        const image_type& img,
        double adjust_threshold
    ) 
    /*!
        Runs the rect_detection overload and returns only the rectangles, in the
        same order.
    !*/
    {
        std::vector<rect_detection> rect_dets;
        (*this)(img, rect_dets, adjust_threshold);

        std::vector<rectangle> boxes;
        boxes.reserve(rect_dets.size());
        for (unsigned long k = 0; k < rect_dets.size(); ++k)
            boxes.push_back(rect_dets[k].rect);

        return boxes;
    }

// ----------------------------------------------------------------------------------------

    template <
519
        typename image_scanner_type
520
521
522
523
        >
    template <
        typename image_type
        >
524
    void object_detector<image_scanner_type>::
525
526
    operator() (
        const image_type& img,
527
528
        std::vector<std::pair<double, rectangle> >& final_dets,
        double adjust_threshold
529
    ) 
530
    {
531
532
        std::vector<rect_detection> dets;
        (*this)(img,dets,adjust_threshold);
533

534
535
536
        final_dets.resize(dets.size());
        for (unsigned long i = 0; i < dets.size(); ++i)
            final_dets[i] = std::make_pair(dets[i].detection_confidence,dets[i].rect);
537
538
    }

// ----------------------------------------------------------------------------------------

    template <
542
        typename image_scanner_type
543
544
545
546
        >
    template <
        typename image_type
        >
547
    void object_detector<image_scanner_type>::
548
549
550
551
552
553
    operator() (
        const image_type& img,
        std::vector<std::pair<double, full_object_detection> >& final_dets,
        double adjust_threshold
    ) 
    {
554
555
        std::vector<rect_detection> dets;
        (*this)(img,dets,adjust_threshold);
556
557

        final_dets.clear();
558
        final_dets.reserve(dets.size());
559
560

        // convert all the rectangle detections into full_object_detections.
561
        for (unsigned long i = 0; i < dets.size(); ++i)
562
        {
563
564
            final_dets.push_back(std::make_pair(dets[i].detection_confidence, 
                                                scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w)));
565
566
567
        }
    }

// ----------------------------------------------------------------------------------------

    template <
571
        typename image_scanner_type
572
573
574
575
        >
    template <
        typename image_type
        >
576
    void object_detector<image_scanner_type>::
577
578
579
580
581
582
    operator() (
        const image_type& img,
        std::vector<full_object_detection>& final_dets,
        double adjust_threshold
    ) 
    {
583
584
        std::vector<rect_detection> dets;
        (*this)(img,dets,adjust_threshold);
585
586

        final_dets.clear();
587
        final_dets.reserve(dets.size());
588
589

        // convert all the rectangle detections into full_object_detections.
590
        for (unsigned long i = 0; i < dets.size(); ++i)
591
        {
592
            final_dets.push_back(scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w));
593
594
595
        }
    }

// ----------------------------------------------------------------------------------------

    template <
599
        typename image_scanner_type
600
        >
601
    const test_box_overlap& object_detector<image_scanner_type>::
602
603
604
605
606
607
608
609
610
    get_overlap_tester (
    ) const
    {
        return boxes_overlap;
    }

// ----------------------------------------------------------------------------------------

    template <
611
        typename image_scanner_type
612
        >
613
    const image_scanner_type& object_detector<image_scanner_type>::
614
615
616
617
618
619
    get_scanner (
    ) const
    {
        return scanner;
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_OBJECT_DeTECTOR_Hh_