// Copyright (C) 2014  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SHAPE_PREDICTOR_H__
#define DLIB_SHAPE_PREDICTOR_H__

#include "dlib/string.h"
#include "dlib/geometry.h"
#include "dlib/data_io/load_image_dataset.h"
#include "dlib/image_processing.h"

using namespace std;

namespace dlib
{

// ----------------------------------------------------------------------------------------

    /*!
        WHAT THIS OBJECT REPRESENTS
            A plain bundle of the parameters consumed by train_shape_predictor_on_images()
            and train_shape_predictor().  The default constructor sets every field to
            the default value shown in the member-initializer list below.
    !*/
    struct shape_predictor_training_options
    {
        shape_predictor_training_options()
            : be_verbose(false),
              cascade_depth(10),
              tree_depth(4),
              num_trees_per_cascade_level(500),
              nu(0.1),
              oversampling_amount(20),
              oversampling_translation_jitter(0),
              feature_pool_size(400),
              lambda_param(0.1),
              num_test_splits(20),
              feature_pool_region_padding(0),
              random_seed(""),
              landmark_relative_padding_mode(true),
              num_threads(0)
        {
        }

        // When true, training settings are echoed to std::cout and the trainer runs verbosely.
        bool be_verbose;
        // Forwarded to shape_predictor_trainer::set_cascade_depth().
        unsigned long cascade_depth;
        // Forwarded to shape_predictor_trainer::set_tree_depth().
        unsigned long tree_depth;
        // Forwarded to shape_predictor_trainer::set_num_trees_per_cascade_level().
        unsigned long num_trees_per_cascade_level;
        // Forwarded to set_nu(); training rejects values outside 0 < nu <= 1.
        double nu;
        // Forwarded to shape_predictor_trainer::set_oversampling_amount().
        unsigned long oversampling_amount;
        // Forwarded to shape_predictor_trainer::set_oversampling_translation_jitter().
        double oversampling_translation_jitter;
        // Forwarded to shape_predictor_trainer::set_feature_pool_size().
        unsigned long feature_pool_size;
        // Forwarded to set_lambda(); training rejects values <= 0.
        double lambda_param;
        // Forwarded to shape_predictor_trainer::set_num_test_splits().
        unsigned long num_test_splits;
        // Forwarded to set_feature_pool_region_padding(); training rejects values <= -0.5.
        double feature_pool_region_padding;
        // Forwarded to shape_predictor_trainer::set_random_seed().
        std::string random_seed;
        // true selects the trainer's landmark_relative padding mode, false selects
        // bounding_box_relative.
        bool landmark_relative_padding_mode;

        // not serialized
        unsigned long num_threads;
    };

    inline void serialize (
        const shape_predictor_training_options& item,
        std::ostream& out
    )
    {
        try
        {
63
            serialize("shape_predictor_training_options_v2", out);
64
65
66
67
68
69
            serialize(item.be_verbose,out);
            serialize(item.cascade_depth,out);
            serialize(item.tree_depth,out);
            serialize(item.num_trees_per_cascade_level,out);
            serialize(item.nu,out);
            serialize(item.oversampling_amount,out);
70
            serialize(item.oversampling_translation_jitter,out);
71
72
73
74
75
            serialize(item.feature_pool_size,out);
            serialize(item.lambda_param,out);
            serialize(item.num_test_splits,out);
            serialize(item.feature_pool_region_padding,out);
            serialize(item.random_seed,out);
76
            serialize(item.landmark_relative_padding_mode,out);
77
78
79
80
81
82
83
84
85
86
87
88
89
90
        }
        catch (serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while serializing an object of type shape_predictor_training_options");
        }
    }

    /*!
        ensures
            - reads item from in, after verifying the version tag
              "shape_predictor_training_options_v2".
            - num_threads is not stored in the stream and is left unchanged.
        throws
            - serialization_error (with extra context appended) if the version tag
              does not match or reading any field fails.
    !*/
    inline void deserialize (
        shape_predictor_training_options& item,
        std::istream& in
    )
    {
        try
        {
            check_serialized_version("shape_predictor_training_options_v2", in);
            // The field order here must match the write order in serialize().
            deserialize(item.be_verbose,in);
            deserialize(item.cascade_depth,in);
            deserialize(item.tree_depth,in);
            deserialize(item.num_trees_per_cascade_level,in);
            deserialize(item.nu,in);
            deserialize(item.oversampling_amount,in);
            deserialize(item.oversampling_translation_jitter,in);
            deserialize(item.feature_pool_size,in);
            deserialize(item.lambda_param,in);
            deserialize(item.num_test_splits,in);
            deserialize(item.feature_pool_region_padding,in);
            deserialize(item.random_seed,in);
            deserialize(item.landmark_relative_padding_mode,in);
        }
        catch (const serialization_error& e)
        {
            throw serialization_error(e.info + "\n   while deserializing an object of type shape_predictor_training_options");
        }
    }

    inline string print_shape_predictor_training_options(const shape_predictor_training_options& o)
113
114
115
    {
        std::ostringstream sout;
        sout << "shape_predictor_training_options("
Davis King's avatar
Davis King committed
116
117
118
119
120
121
122
123
124
125
126
127
            << "be_verbose=" << o.be_verbose << ", "
            << "cascade_depth=" << o.cascade_depth << ", "
            << "tree_depth=" << o.tree_depth << ", "
            << "num_trees_per_cascade_level=" << o.num_trees_per_cascade_level << ", "
            << "nu=" << o.nu << ", "
            << "oversampling_amount=" << o.oversampling_amount << ", "
            << "oversampling_translation_jitter=" << o.oversampling_translation_jitter << ", "
            << "feature_pool_size=" << o.feature_pool_size << ", "
            << "lambda_param=" << o.lambda_param << ", "
            << "num_test_splits=" << o.num_test_splits << ", "
            << "feature_pool_region_padding=" << o.feature_pool_region_padding << ", "
            << "random_seed=" << o.random_seed << ", "
128
            << "num_threads=" << o.num_threads
129
            << "landmark_relative_padding_mode=" << o.landmark_relative_padding_mode
130
131
132
133
        << ")";
        return sout.str();
    }

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        inline bool contains_any_detections (
            const std::vector<std::vector<full_object_detection> >& detections
        )
        {
            for (unsigned long i = 0; i < detections.size(); ++i)
            {
                if (detections[i].size() != 0)
                    return true;
            }
            return false;
        }
    }

// ----------------------------------------------------------------------------------------

    template <typename image_array>
154
    inline shape_predictor train_shape_predictor_on_images (
155
156
157
158
159
        image_array& images,
        std::vector<std::vector<full_object_detection> >& detections,
        const shape_predictor_training_options& options
    )
    {
160
161
        if (options.lambda_param <= 0)
            throw error("Invalid lambda_param value given to train_shape_predictor(), lambda_param must be > 0.");
162
163
        if (!(0 < options.nu && options.nu <= 1))
            throw error("Invalid nu value given to train_shape_predictor(). It is required that 0 < nu <= 1.");
Davis King's avatar
Davis King committed
164
165
        if (options.feature_pool_region_padding <= -0.5)
            throw error("Invalid feature_pool_region_padding value given to train_shape_predictor(), feature_pool_region_padding must be > -0.5.");
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180

        if (images.size() != detections.size())
            throw error("The list of images must have the same length as the list of detections.");

        if (!impl::contains_any_detections(detections))
            throw error("Error, the training dataset does not have any labeled object detections in it.");

        shape_predictor_trainer trainer;

        trainer.set_cascade_depth(options.cascade_depth);
        trainer.set_tree_depth(options.tree_depth);
        trainer.set_num_trees_per_cascade_level(options.num_trees_per_cascade_level);
        trainer.set_nu(options.nu);
        trainer.set_random_seed(options.random_seed);
        trainer.set_oversampling_amount(options.oversampling_amount);
181
        trainer.set_oversampling_translation_jitter(options.oversampling_translation_jitter);
182
183
        trainer.set_feature_pool_size(options.feature_pool_size);
        trainer.set_feature_pool_region_padding(options.feature_pool_region_padding);
184
        trainer.set_lambda(options.lambda_param);
185
        trainer.set_num_test_splits(options.num_test_splits);
186
        trainer.set_num_threads(options.num_threads);
187
188
189
190
        if (options.landmark_relative_padding_mode)
            trainer.set_padding_mode(shape_predictor_trainer::landmark_relative);
        else
            trainer.set_padding_mode(shape_predictor_trainer::bounding_box_relative);
191
192
193
194
195
196
197
198
199

        if (options.be_verbose)
        {
            std::cout << "Training with cascade depth: " << options.cascade_depth << std::endl;
            std::cout << "Training with tree depth: " << options.tree_depth << std::endl;
            std::cout << "Training with " << options.num_trees_per_cascade_level << " trees per cascade level."<< std::endl;
            std::cout << "Training with nu: " << options.nu << std::endl;
            std::cout << "Training with random seed: " << options.random_seed << std::endl;
            std::cout << "Training with oversampling amount: " << options.oversampling_amount << std::endl;
200
            std::cout << "Training with oversampling translation jitter: " << options.oversampling_translation_jitter << std::endl;
201
            std::cout << "Training with landmark_relative_padding_mode: " << options.landmark_relative_padding_mode << std::endl;
202
203
            std::cout << "Training with feature pool size: " << options.feature_pool_size << std::endl;
            std::cout << "Training with feature pool region padding: " << options.feature_pool_region_padding << std::endl;
204
            std::cout << "Training with " << options.num_threads << " threads." << std::endl;
205
            std::cout << "Training with lambda_param: " << options.lambda_param << std::endl;
206
207
208
209
210
211
            std::cout << "Training with " << options.num_test_splits << " split tests."<< std::endl;
            trainer.be_verbose();
        }

        shape_predictor predictor = trainer.train(images, detections);

212
        return predictor;
213
214
215
216
217
218
219
220
    }

    inline void train_shape_predictor (
        const std::string& dataset_filename,
        const std::string& predictor_output_filename,
        const shape_predictor_training_options& options
    )
    {
221
        dlib::array<array2d<unsigned char> > images;
222
223
224
        std::vector<std::vector<full_object_detection> > objects;
        load_image_dataset(images, objects, dataset_filename);

225
226
        shape_predictor predictor = train_shape_predictor_on_images(images, objects, options);

227
        serialize(predictor_output_filename) << predictor;
228
229
230

        if (options.be_verbose)
            std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl;
231
232
233
234
235
236
237
238
239
    }

// ----------------------------------------------------------------------------------------

    template <typename image_array>
    inline double test_shape_predictor_with_images (
            image_array& images,
            std::vector<std::vector<full_object_detection> >& detections,
            std::vector<std::vector<double> >& scales,
240
            const shape_predictor& predictor
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
    )
    {
        if (images.size() != detections.size())
            throw error("The list of images must have the same length as the list of detections.");
        if (scales.size() > 0  && scales.size() != images.size())
            throw error("The list of scales must have the same length as the list of detections.");

        if (scales.size() > 0)
            return test_shape_predictor(predictor, images, detections, scales);
        else
            return test_shape_predictor(predictor, images, detections);
    }

    inline double test_shape_predictor_py (
        const std::string& dataset_filename,
        const std::string& predictor_filename
    )
    {
259
        // Load the images, no scales can be provided
260
        dlib::array<array2d<unsigned char> > images;
261
262
263
264
265
        // This interface cannot take the scales parameter.
        std::vector<std::vector<double> > scales;
        std::vector<std::vector<full_object_detection> > objects;
        load_image_dataset(images, objects, dataset_filename);

266
267
        // Load the shape predictor
        shape_predictor predictor;
268
        deserialize(predictor_filename) >> predictor;
269
270

        return test_shape_predictor_with_images(images, objects, scales, predictor);
271
272
273
274
275
276
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SHAPE_PREDICTOR_H__