shape_predictor.cpp 14.9 KB
Newer Older
1
2
3
// Copyright (C) 2014  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

Davis King's avatar
Davis King committed
4
#include "opaque_types.h"
5
6
7
8
9
10
11
12
#include <dlib/python.h>
#include <dlib/geometry.h>
#include <dlib/image_processing.h>
#include "shape_predictor.h"
#include "conversion.h"

using namespace dlib;
using namespace std;
13
14

namespace py = pybind11;
15
16
17
18
19

// ----------------------------------------------------------------------------------------

// Runs `predictor` on the numpy image `img`, starting from the bounding box
// `rect` (anything castable to a dlib rectangle), and returns the predicted
// part locations.
//
// Only 8-bit grayscale and RGB numpy images are accepted; any other image
// type raises dlib::error.
full_object_detection run_predictor (
        shape_predictor& predictor,
        py::object img,
        py::object rect
)
{
    const rectangle box = rect.cast<rectangle>();

    // Dispatch on the pixel layout of the incoming numpy array.
    if (is_gray_python_image(img))
        return predictor(numpy_gray_image(img), box);

    if (is_rgb_python_image(img))
        return predictor(numpy_rgb_image(img), box);

    throw dlib::error("Unsupported image type, must be 8bit gray or RGB image.");
}

39
40
41
42
43
44
void save_shape_predictor(const shape_predictor& predictor, const std::string& predictor_output_filename)
{
    std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary);
    serialize(predictor, fout);
}

45
46
47
48
49
50
51
52
53
54
// ----------------------------------------------------------------------------------------

// Python "rect" property: the bounding box stored in the detection.
rectangle full_obj_det_get_rect (const full_object_detection& detection)
{
    const rectangle& box = detection.get_rect();
    return box;
}

// Python "num_parts" property: how many part locations the detection holds.
unsigned long full_obj_det_num_parts (const full_object_detection& detection)
{
    const unsigned long count = detection.num_parts();
    return count;
}

// Returns the location of part number `idx` of `detection`.
//
// The index is range-checked here so that an out-of-range value surfaces in
// Python as an IndexError instead of reaching detection.part() unchecked.
point full_obj_det_part (const full_object_detection& detection, const unsigned long idx)
{
    if (idx >= detection.num_parts())
        throw py::index_error("Index out of range");
    return detection.part(idx);
}

std::vector<point> full_obj_det_parts (const full_object_detection& detection)
{
    const unsigned long num_parts = detection.num_parts();
    std::vector<point> parts(num_parts);
    for (unsigned long j = 0; j < num_parts; ++j)
        parts[j] = detection.part(j);
    return parts;
}

72
// Python constructor for full_object_detection.
//
//   pyrect  - anything castable to a dlib rectangle (the object's bounding box)
//   pyparts - an iterable of objects castable to dlib points (the part locations)
//
// The parts are collected by iterating `pyparts`. Earlier code preallocated
// from len(pyparts) and wrote by index, which writes out of bounds if the
// iterable yields more items than its __len__ reports. Any iterable is now
// accepted.
std::shared_ptr<full_object_detection> full_obj_det_init(py::object& pyrect, py::object& pyparts)
{
    const rectangle rect = pyrect.cast<rectangle>();

    std::vector<point> parts;
    for (const py::handle item : pyparts)
        parts.push_back(item.cast<point>());

    return std::make_shared<full_object_detection>(rect, parts);
}

// ----------------------------------------------------------------------------------------

89
// Trains a shape_predictor from in-memory data.
//
//   pyimages     - list of numpy images (converted via images_and_nested_params_to_dlib)
//   pydetections - list (one entry per image) of lists of full_object_detection
//   options      - trainer settings, forwarded unchanged
//
// Throws dlib::error when the two lists differ in length.
inline shape_predictor train_shape_predictor_on_images_py (
        const py::list& pyimages,
        const py::list& pydetections,
        const shape_predictor_training_options& options
)
{
    const unsigned long num_images = py::len(pyimages);
    if (num_images != py::len(pydetections))
        throw dlib::error("The length of the detections list must match the length of the images list.");

    std::vector<std::vector<full_object_detection> > detections(num_images);
    dlib::array<array2d<unsigned char> > images(num_images);
    images_and_nested_params_to_dlib(pyimages, pydetections, images, detections);

    return train_shape_predictor_on_images(images, detections, options);
}


inline double test_shape_predictor_with_images_py (
108
109
110
        const py::list& pyimages,
        const py::list& pydetections,
        const py::list& pyscales,
111
        const shape_predictor& predictor
112
113
)
{
114
115
116
    const unsigned long num_images = py::len(pyimages);
    const unsigned long num_scales = py::len(pyscales);
    if (num_images != py::len(pydetections))
117
118
119
120
121
122
123
124
125
        throw dlib::error("The length of the detections list must match the length of the images list.");

    if (num_scales > 0 && num_scales != num_images)
        throw dlib::error("The length of the scales list must match the length of the detections list.");

    std::vector<std::vector<full_object_detection> > detections(num_images);
    std::vector<std::vector<double> > scales;
    if (num_scales > 0)
        scales.resize(num_scales);
126
    dlib::array<array2d<unsigned char> > images(num_images);
127

128
    // Now copy the data into dlib based objects so we can call the testing routine.
129
130
    for (unsigned long i = 0; i < num_images; ++i)
    {
131
132
133
134
135
        const unsigned long num_boxes = py::len(pydetections[i]);
        for (py::iterator det_it = pydetections[i].begin();
             det_it != pydetections[i].end();
             ++det_it)
          detections[i].push_back(det_it->cast<full_object_detection>());
136
137
138
139

        pyimage_to_dlib_image(pyimages[i], images[i]);
        if (num_scales > 0)
        {
140
            if (num_boxes != py::len(pyscales[i]))
141
                throw dlib::error("The length of the scales list must match the length of the detections list.");
142
143
144
145
            for (py::iterator scale_it = pyscales[i].begin();
                 scale_it != pyscales[i].end();
                 ++scale_it)
                scales[i].push_back(scale_it->cast<double>());
146
147
148
        }
    }

149
    return test_shape_predictor_with_images(images, detections, scales, predictor);
150
151
152
}

inline double test_shape_predictor_with_images_no_scales_py (
153
154
        const py::list& pyimages,
        const py::list& pydetections,
155
        const shape_predictor& predictor
156
157
)
{
158
    py::list pyscales;
159
    return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor);
160
161
162
163
}

// ----------------------------------------------------------------------------------------

164
void bind_shape_predictors(py::module &m)
165
166
167
{
    {
    typedef full_object_detection type;
168
    py::class_<type, std::shared_ptr<type>>(m, "full_object_detection",
169
170
    "This object represents the location of an object in an image along with the \
    positions of each of its constituent parts.")
171
        .def(py::init(&full_obj_det_init),
172
173
174
"requires \n\
    - rect: dlib rectangle \n\
    - parts: list of dlib points")
175
176
177
        .def_property_readonly("rect", &full_obj_det_get_rect, "Bounding box from the underlying detector. Parts can be outside box if appropriate.")
        .def_property_readonly("num_parts", &full_obj_det_num_parts, "The number of parts of the object.")
        .def("part", &full_obj_det_part, py::arg("idx"), "A single part of the object as a dlib point.")
178
        .def("parts", &full_obj_det_parts, "A vector of dlib points representing all of the parts.")
179
        .def(py::pickle(&getstate<type>, &setstate<type>));
180
181
182
    }
    {
    typedef shape_predictor_training_options type;
183
    py::class_<type>(m, "shape_predictor_training_options",
184
        "This object is a container for the options to the train_shape_predictor() routine.")
185
186
        .def(py::init())
        .def_readwrite("be_verbose", &type::be_verbose,
187
                      "If true, train_shape_predictor() will print out a lot of information to stdout while training.")
188
        .def_readwrite("cascade_depth", &type::cascade_depth,
189
                      "The number of cascades created to train the model with.")
190
        .def_readwrite("tree_depth", &type::tree_depth,
191
                      "The depth of the trees used in each cascade. There are pow(2, get_tree_depth()) leaves in each tree")
192
        .def_readwrite("num_trees_per_cascade_level", &type::num_trees_per_cascade_level,
193
                      "The number of trees created for each cascade.")
194
        .def_readwrite("nu", &type::nu,
195
196
                      "The regularization parameter.  Larger values of this parameter \
                       will cause the algorithm to fit the training data better but may also \
197
                       cause overfitting.  The value must be in the range (0, 1].")
198
        .def_readwrite("oversampling_amount", &type::oversampling_amount,
199
                      "The number of randomly selected initial starting points sampled for each training example")
200
        .def_readwrite("feature_pool_size", &type::feature_pool_size,
201
                      "Number of pixels used to generate features for the random trees.")
202
        .def_readwrite("lambda_param", &type::lambda_param,
203
                      "Controls how tight the feature sampling should be. Lower values enforce closer features.")
204
        .def_readwrite("num_test_splits", &type::num_test_splits,
205
                      "Number of split features at each node to sample. The one that gives the best split is chosen.")
206
        .def_readwrite("feature_pool_region_padding", &type::feature_pool_region_padding,
207
208
                      "Size of region within which to sample features for the feature pool, \
                      e.g a padding of 0.5 would cause the algorithm to sample pixels from a box that was 2x2 pixels")
209
        .def_readwrite("random_seed", &type::random_seed,
210
                      "The random seed used by the internal random number generator")
211
212
        .def_readwrite("num_threads", &type::num_threads,
                        "Use this many threads/CPU cores for training.")
213
        .def("__str__", &::print_shape_predictor_training_options)
214
        .def(py::pickle(&getstate<type>, &setstate<type>));
215
216
217
    }
    {
    typedef shape_predictor type;
218
    py::class_<type, std::shared_ptr<type>>(m, "shape_predictor",
219
220
221
222
223
"This object is a tool that takes in an image region containing some object and \
outputs a set of point locations that define the pose of the object. The classic \
example of this is human face pose prediction, where you take an image of a human \
face as input and are expected to identify the locations of important facial \
landmarks such as the corners of the mouth and eyes, tip of the nose, and so forth.")
224
225
        .def(py::init())
        .def(py::init(&load_object_from_file<type>),
226
227
"Loads a shape_predictor from a file that contains the output of the \n\
train_shape_predictor() routine.")
228
        .def("__call__", &run_predictor, py::arg("image"), py::arg("box"),
229
230
231
232
233
234
"requires \n\
    - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\
      image. \n\
    - box is the bounding box to begin the shape prediction inside. \n\
ensures \n\
    - This function runs the shape predictor on the input image and returns \n\
235
      a single full_object_detection.")
236
237
        .def("save", save_shape_predictor, py::arg("predictor_output_filename"), "Save a shape_predictor to the provided path.")
        .def(py::pickle(&getstate<type>, &setstate<type>));
238
239
    }
    {
240
241
    m.def("train_shape_predictor", train_shape_predictor_on_images_py,
        py::arg("images"), py::arg("object_detections"), py::arg("options"),
242
"requires \n\
243
    - options.lambda_param > 0 \n\
244
    - 0 < options.nu <= 1 \n\
245
246
247
248
249
250
    - options.feature_pool_region_padding >= 0 \n\
    - len(images) == len(object_detections) \n\
    - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
    - object_detections should be a list of lists of dlib.full_object_detection objects. \
      Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
ensures \n\
251
252
    - Uses dlib's shape_predictor_trainer object to train a \n\
      shape_predictor based on the provided labeled images, full_object_detections, and options.\n\
253
    - The trained shape_predictor is returned");
254

255
256
    m.def("train_shape_predictor", train_shape_predictor,
        py::arg("dataset_filename"), py::arg("predictor_output_filename"), py::arg("options"),
257
"requires \n\
258
    - options.lambda_param > 0 \n\
259
    - 0 < options.nu <= 1 \n\
260
261
    - options.feature_pool_region_padding >= 0 \n\
ensures \n\
262
    - Uses dlib's shape_predictor_trainer to train a \n\
263
      shape_predictor based on the labeled images in the XML file \n\
264
      dataset_filename and the provided options.  This function assumes the file dataset_filename is in the \n\
265
266
267
      XML format produced by dlib's save_image_dataset_metadata() routine. \n\
    - The trained shape predictor is serialized to the file predictor_output_filename.");

268
269
    m.def("test_shape_predictor", test_shape_predictor_py,
        py::arg("dataset_filename"), py::arg("predictor_filename"),
270
271
272
273
274
275
276
277
278
279
280
281
"ensures \n\
    - Loads an image dataset from dataset_filename.  We assume dataset_filename is \n\
      a file using the XML format written by save_image_dataset_metadata(). \n\
    - Loads a shape_predictor from the file predictor_filename.  This means \n\
      predictor_filename should be a file produced by the train_shape_predictor() \n\
      routine. \n\
    - This function tests the predictor against the dataset and returns the \n\
      mean average error of the detector.  In fact, The \n\
      return value of this function is identical to that of dlib's \n\
      shape_predictor_trainer() routine.  Therefore, see the documentation \n\
      for shape_predictor_trainer() for a detailed definition of the mean average error.");

282
283
    m.def("test_shape_predictor", test_shape_predictor_with_images_no_scales_py,
            py::arg("images"), py::arg("detections"), py::arg("shape_predictor"),
284
285
286
287
288
289
"requires \n\
    - len(images) == len(object_detections) \n\
    - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
    - object_detections should be a list of lists of dlib.full_object_detection objects. \
      Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
 ensures \n\
290
    - shape_predictor should be a file produced by the train_shape_predictor()  \n\
291
292
293
294
295
296
297
298
      routine. \n\
    - This function tests the predictor against the dataset and returns the \n\
      mean average error of the detector.  In fact, The \n\
      return value of this function is identical to that of dlib's \n\
      shape_predictor_trainer() routine.  Therefore, see the documentation \n\
      for shape_predictor_trainer() for a detailed definition of the mean average error.");


299
300
    m.def("test_shape_predictor", test_shape_predictor_with_images_py,
            py::arg("images"), py::arg("detections"), py::arg("scales"), py::arg("shape_predictor"),
301
302
303
304
305
306
307
308
309
310
"requires \n\
    - len(images) == len(object_detections) \n\
    - len(object_detections) == len(scales) \n\
    - for every sublist in object_detections: len(object_detections[i]) == len(scales[i]) \n\
    - scales is a list of floating point scales that each predicted part location \
      should be divided by. Useful for normalization. \n\
    - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
    - object_detections should be a list of lists of dlib.full_object_detection objects. \
      Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
 ensures \n\
311
    - shape_predictor should be a file produced by the train_shape_predictor()  \n\
312
313
314
315
316
317
318
319
      routine. \n\
    - This function tests the predictor against the dataset and returns the \n\
      mean average error of the detector.  In fact, The \n\
      return value of this function is identical to that of dlib's \n\
      shape_predictor_trainer() routine.  Therefore, see the documentation \n\
      for shape_predictor_trainer() for a detailed definition of the mean average error.");
    }
}