// Copyright (C) 2017  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include "opaque_types.h"
#include <dlib/python.h>
#include <dlib/matrix.h>
#include <dlib/dnn.h>
#include <dlib/image_transforms.h>
#include "indexing.h"
#include <pybind11/stl_bind.h>

using namespace dlib;
using namespace std;

namespace py = pybind11;


class cnn_face_detection_model_v1
{

public:

    cnn_face_detection_model_v1(const std::string& model_filename)
    {
        deserialize(model_filename) >> net;
    }

28
    std::vector<mmod_rect> detect (
29
        py::object pyimage,
30
31
32
33
        const int upsample_num_times
    )
    {
        pyramid_down<2> pyr;
34
        std::vector<mmod_rect> rects;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

        // Copy the data into dlib based objects
        matrix<rgb_pixel> image;
        if (is_gray_python_image(pyimage))
            assign_image(image, numpy_gray_image(pyimage));
        else if (is_rgb_python_image(pyimage))
            assign_image(image, numpy_rgb_image(pyimage));
        else
            throw dlib::error("Unsupported image type, must be 8bit gray or RGB image.");

        // Upsampling the image will allow us to detect smaller faces but will cause the
        // program to use more RAM and run longer.
        unsigned int levels = upsample_num_times;
        while (levels > 0)
        {
            levels--;
            pyramid_up(image, pyr);
        }

        auto dets = net(image);

        // Scale the detection locations back to the original image size
        // if the image was upscaled.
        for (auto&& d : dets) {
            d.rect = pyr.rect_down(d.rect, upsample_num_times);
60
            rects.push_back(d);
61
62
63
64
65
        }

        return rects;
    }

66
    std::vector<std::vector<mmod_rect> > detect_mult (
67
        py::list imgs,
68
69
70
71
72
73
74
75
76
77
78
79
        const int upsample_num_times,
        const int batch_size = 128
    )
    {
        pyramid_down<2> pyr;
        std::vector<matrix<rgb_pixel> > dimgs;
        dimgs.reserve(len(imgs));

        for(int i = 0; i < len(imgs); i++)
        {
            // Copy the data into dlib based objects
            matrix<rgb_pixel> image;
80
            py::object tmp = imgs[i].cast<py::object>();
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
            if (is_gray_python_image(tmp))
                assign_image(image, numpy_gray_image(tmp));
            else if (is_rgb_python_image(tmp))
                assign_image(image, numpy_rgb_image(tmp));
            else
                throw dlib::error("Unsupported image type, must be 8bit gray or RGB image.");

            for(int i = 0; i < upsample_num_times; i++)
            {
                pyramid_up(image);
            }
            dimgs.push_back(image);
        }

        for(int i = 1; i < dimgs.size(); i++)
        {
            if
            (
                dimgs[i - 1].nc() != dimgs[i].nc() ||
                dimgs[i - 1].nr() != dimgs[i].nr()
            )
                throw dlib::error("Images in list must all have the same dimensions.");
            
        }        

        auto dets = net(dimgs, batch_size);
        std::vector<std::vector<mmod_rect> > all_rects;

        for(auto&& im_dets : dets)
        {
            std::vector<mmod_rect> rects;
            rects.reserve(im_dets.size());
            for (auto&& d : im_dets) {
                d.rect = pyr.rect_down(d.rect, upsample_num_times);
                rects.push_back(d);
            }
            all_rects.push_back(rects);
        }
        
        return all_rects;
    }

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
private:

    template <long num_filters, typename SUBNET> using con5d = con<num_filters,5,5,2,2,SUBNET>;
    template <long num_filters, typename SUBNET> using con5  = con<num_filters,5,5,1,1,SUBNET>;

    template <typename SUBNET> using downsampler  = relu<affine<con5d<32, relu<affine<con5d<32, relu<affine<con5d<16,SUBNET>>>>>>>>>;
    template <typename SUBNET> using rcon5  = relu<affine<con5<45,SUBNET>>>;

    using net_type = loss_mmod<con<1,9,9,1,1,rcon5<rcon5<rcon5<downsampler<input_rgb_image_pyramid<pyramid_down<6>>>>>>>>;

    net_type net;
};

// ----------------------------------------------------------------------------------------

// Registers the CNN face detector, its result types (mmod_rectangle and the
// 1D/2D vectors of it), and a few DNN/CUDA configuration helpers on module m.
void bind_cnn_face_detection(py::module& m)
{
    using detector = cnn_face_detection_model_v1;
    using rects_1d = std::vector<mmod_rect>;
    using rects_2d = std::vector<std::vector<mmod_rect> >;

    // Both detection entry points are exposed as __call__. pybind11 tries
    // overloads in registration order, so the list-of-images form is tried first.
    py::class_<detector>(m, "cnn_face_detection_model_v1", "This object detects human faces in an image.  The constructor loads the face detection model from a file. You can download a pre-trained model from http://dlib.net/files/mmod_human_face_detector.dat.bz2.")
        .def(py::init<std::string>())
        .def("__call__", &detector::detect_mult,
             py::arg("imgs"), py::arg("upsample_num_times")=0, py::arg("batch_size")=128,
             "takes a list of images as input returning a 2d list of mmod rectangles")
        .def("__call__", &detector::detect,
             py::arg("img"), py::arg("upsample_num_times")=0,
             "Find faces in an image using a deep learning model.\n\
          - Upsamples the image upsample_num_times before running the face \n\
            detector.");

    m.def("set_dnn_prefer_smallest_algorithms", &set_dnn_prefer_smallest_algorithms,
          "Tells cuDNN to use slower algorithms that use less RAM.");

    // CUDA device selection lives in its own "cuda" submodule.
    auto cuda_module = m.def_submodule("cuda", "Routines for setting CUDA specific properties.");
    cuda_module.def("set_device", &dlib::cuda::set_device, py::arg("device_id"),
        "Set the active CUDA device.  It is required that 0 <= device_id < get_num_devices().");
    cuda_module.def("get_device", &dlib::cuda::get_device, "Get the active CUDA device.");
    cuda_module.def("get_num_devices", &dlib::cuda::get_num_devices, "Find out how many CUDA devices are available.");

    // A single detection: its rectangle plus the detector's confidence score.
    py::class_<mmod_rect>(m, "mmod_rectangle", "Wrapper around a rectangle object and a detection confidence score.")
        .def_readwrite("rect", &mmod_rect::rect)
        .def_readwrite("confidence", &mmod_rect::detection_confidence);

    // Result containers for single-image and multi-image detection.
    py::bind_vector<rects_1d>(m, "mmod_rectangles", "An array of mmod rectangle objects.")
        .def("extend", extend_vector_with_python_list<mmod_rect>);

    py::bind_vector<rects_2d>(m, "mmod_rectangless", "A 2D array of mmod rectangle objects.")
        .def("extend", extend_vector_with_python_list<std::vector<mmod_rect>>);
}