// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This example shows how to do semantic segmentation on an image using a net pretrained
    on the PASCAL VOC2012 dataset.  For an introduction to what segmentation is, see the
    accompanying header file dnn_semantic_segmentation_ex.h.

    Instructions on how to run the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_semantic_segmentation_train_ex example program.
    3. Run:
       ./dnn_semantic_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_semantic_segmentation_ex example program.
    6. Run:
       ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

    An alternative to steps 2-4 above is to download a pre-trained network
    from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example.  So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
    before reading this example program.
*/

#include "dnn_semantic_segmentation_ex.h"

#include <iostream>
#include <dlib/data_io.h>
#include <dlib/gui_widgets.h>

using namespace std;
using namespace dlib;
 
// ----------------------------------------------------------------------------------------

// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background.  Each class
// is represented using an RGB color value.  We also associate each class with an index in
// the range [0, 20], which is used internally by the network.  To generate nice RGB
// representations of inference results, we need to be able to convert the index values
// to the corresponding RGB values.

// Look up the PASCAL VOC2012 class (e.g., 'dog') whose index equals index_label, where
// index_label is expected to be in the range [0, 20].  The search itself is delegated to
// the predicate-taking find_voc2012_class overload, which is declared outside this file
// (presumably in dnn_semantic_segmentation_ex.h).
const Voc2012class& find_voc2012_class(const uint16_t& index_label)
{
    // Predicate that accepts exactly the class carrying the requested index.
    const auto has_requested_index = [&index_label](const Voc2012class& candidate)
    {
        return index_label == candidate.index;
    };

    return find_voc2012_class(has_requested_index);
}

// Map a class index in the range [0, 20] to the RGB label of the matching
// PASCAL VOC2012 class.
inline rgb_pixel index_label_to_rgb_label(uint16_t index_label)
{
    const Voc2012class& matching_class = find_voc2012_class(index_label);
    return matching_class.rgb_label;
}

// Translate an image of class indexes (each in the range [0, 20]) into an image of the
// corresponding RGB class labels.  The output image is resized to match the input, and
// every one of its pixels is overwritten.
void index_label_image_to_rgb_label_image(
    const matrix<uint16_t>& index_label_image,
    matrix<rgb_pixel>& rgb_label_image
)
{
    const long row_count = index_label_image.nr();
    const long col_count = index_label_image.nc();

    rgb_label_image.set_size(row_count, col_count);

    for (long row = 0; row < row_count; ++row)
    {
        for (long col = 0; col < col_count; ++col)
        {
            const uint16_t index_label = index_label_image(row, col);
            rgb_label_image(row, col) = index_label_to_rgb_label(index_label);
        }
    }
}

// Find the most prominent non-background class label amongst the per-pixel predictions.
//
// Counts how many pixels carry each index label, and returns the classlabel of the most
// frequent index in the range [1, class_count).  Index 0 is excluded from the search
// (it is presumably the background class, as the function name suggests).
//
// Labels that are not smaller than class_count are skipped rather than counted, because
// they would otherwise index past the end of the counters.
//
// Note: if no pixel carries a label in [1, class_count), all candidates tie at zero and
// std::max_element returns the first of them, so the class with index 1 is reported.
std::string get_most_prominent_non_background_classlabel(const matrix<uint16_t>& index_label_image)
{
    const long nr = index_label_image.nr();
    const long nc = index_label_image.nc();

    // One counter per class index; value-initialized to zero.
    std::vector<unsigned int> counters(class_count);

    for (long r = 0; r < nr; ++r)
    {
        for (long c = 0; c < nc; ++c)
        {
            const uint16_t label = index_label_image(r, c);

            // Guard against labels outside the known class range.
            if (label < counters.size())
            {
                ++counters[label];
            }
        }
    }

    // Start the search at index 1, so that index 0 can never win.
    const auto most_frequent = std::max_element(counters.begin() + 1, counters.end());
    const auto most_prominent_index_label =
        static_cast<uint16_t>(most_frequent - counters.begin());

    return find_voc2012_class(most_prominent_index_label).classlabel;
}

// ----------------------------------------------------------------------------------------

int main(int argc, char** argv) try
{
    if (argc != 2)
    {
        cout << "You call this program like this: " << endl;
        cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl;
        cout << endl;
114
        cout << "You will also need a trained '" << semantic_segmentation_net_filename << "' file." << endl;
115
116
        cout << "You can either train it yourself (see example program" << endl;
        cout << "dnn_semantic_segmentation_train_ex), or download a" << endl;
117
        cout << "copy from here: http://dlib.net/files/" << semantic_segmentation_net_filename << endl;
118
119
120
121
122
        return 1;
    }

    // Read the file containing the trained network from the working directory.
    anet_type net;
123
    deserialize(semantic_segmentation_net_filename) >> net;
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

    // Show inference results in a window.
    image_window win;

    matrix<rgb_pixel> input_image;
    matrix<uint16_t> index_label_image;
    matrix<rgb_pixel> rgb_label_image;

    // Find supported image files.
    const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1],
        dlib::match_endings(".jpeg .jpg .png"));

    cout << "Found " << files.size() << " images, processing..." << endl;

    for (const file& file : files)
    {
        // Load the input image.
        load_image(input_image, file.full_name());

        // Create predictions for each pixel. At this point, the type of each prediction
        // is an index (a value between 0 and 20). Note that the net may return an image
        // that is not exactly the same size as the input.
        const matrix<uint16_t> temp = net(input_image);

        // Crop the returned image to be exactly the same size as the input.
        const chip_details chip_details(
            centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()),
            chip_dims(input_image.nr(), input_image.nc())
        );
        extract_image_chip(temp, chip_details, index_label_image, interpolate_nearest_neighbor());

        // Convert the indexes to RGB values.
        index_label_image_to_rgb_label_image(index_label_image, rgb_label_image);

        // Show the input image on the left, and the predicted RGB labels on the right.
        win.set_image(join_rows(input_image, rgb_label_image));

        // Find the most prominent class label from amongst the per-pixel predictions.
        const std::string classlabel = get_most_prominent_non_background_classlabel(index_label_image);

        cout << file.name() << " : " << classlabel << " - hit enter to process the next image";
        cin.get();
    }
}
catch(std::exception& e)
{
    cout << e.what() << endl;
}