dnn_imagenet_ex.cpp 5.77 KB
Newer Older
1
2
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This example shows how to classify an image into one of the 1000 imagenet
    categories using the deep learning tools from the dlib C++ Library.  We will
    use the pretrained ResNet34 model available on the dlib website.

    The ResNet34 architecture is from the paper Deep Residual Learning for Image
    Recognition by He, Zhang, Ren, and Sun.  The model file that comes with dlib
    was trained using the dnn_imagenet_train_ex.cpp program on a Titan X for
    about 2 weeks.  This pretrained model has a top5 error of 7.572% on the 2012
    imagenet validation dataset.

    For an introduction to dlib's DNN module read the dnn_introduction_ex.cpp and
    dnn_introduction2_ex.cpp example programs.

    Finally, these tools will use CUDA and cuDNN to drastically accelerate
    network training and testing.  CMake should automatically find them if they
    are installed and configure things appropriately.  If not, the program will
    still run but will be much slower to execute.
*/



#include <dlib/dnn.h>
#include <iostream>
#include <dlib/data_io.h>
#include <dlib/gui_widgets.h>
#include <dlib/image_transforms.h>

using namespace std;
using namespace dlib;
 
// ----------------------------------------------------------------------------------------

36
// This block of statements defines the resnet-34 network
37
38
39
40
41
42
43
44
45
46
47
48
49
50

// Residual unit that keeps the spatial resolution.  The input SUBNET is marked
// with tag1, run through BLOCK with a stride of 1, and add_prev1 then combines
// the BLOCK output with the tag1 layer (the skip connection).
//   BLOCK   - layer template taking <filter count, normalization layer, stride, subnet>
//   FILTERS - filter count forwarded to BLOCK
//   NORM    - normalization layer template forwarded to BLOCK
template <
    template <int, template <typename> class, int, typename> class BLOCK,
    int FILTERS,
    template <typename> class NORM,
    typename SUBNET
    >
using residual = add_prev1<BLOCK<FILTERS, NORM, 1, tag1<SUBNET>>>;

// Residual unit that downsamples.  The input SUBNET is marked with tag1 and run
// through BLOCK with a stride of 2; that result is marked with tag2.  skip1 then
// goes back to the tag1 layer, which is shrunk by a 2x2 average pool with stride
// 2, and add_prev2 combines the pooled input with the tag2 (BLOCK) output.
//   BLOCK   - layer template taking <filter count, normalization layer, stride, subnet>
//   FILTERS - filter count forwarded to BLOCK
//   NORM    - normalization layer template forwarded to BLOCK
template <
    template <int, template <typename> class, int, typename> class BLOCK,
    int FILTERS,
    template <typename> class NORM,
    typename SUBNET
    >
using residual_down = add_prev2<
                          avg_pool<2, 2, 2, 2,
                          skip1<
                          tag2<BLOCK<FILTERS, NORM, 2, tag1<SUBNET>>>
                          >>>;

// The basic building block used by the residual units: reading from the input
// outwards it is a 3x3 convolution with FILTERS filters and the requested
// stride, a NORM layer, a relu, a second 3x3 convolution with FILTERS filters
// and a stride of 1, and a final NORM layer.
template <
    int FILTERS,
    template <typename> class NORM,
    int STRIDE,
    typename SUBNET
    >
using block = NORM<con<FILTERS, 3, 3, 1, 1,
              relu<
              NORM<con<FILTERS, 3, 3, STRIDE, STRIDE, SUBNET>>
              >>>;

// Convenience aliases that plug `block` and the `affine` layer into the two
// residual templates and put a relu on top.  `ares` keeps the resolution,
// `ares_down` is the downsampling variant.
template <int FILTERS, typename SUBNET>
using ares = relu<residual<block, FILTERS, affine, SUBNET>>;

template <int FILTERS, typename SUBNET>
using ares_down = relu<residual_down<block, FILTERS, affine, SUBNET>>;


51
// The complete resnet-34 classifier type.  Read it from the bottom up: a
// 227x227 RGB input goes through a 7x7 stride-2 convolution with 64 filters,
// affine, relu and a 3x3 stride-2 max pool, then through four stages of
// residual units (3 units of 64 filters, 4 of 128, 6 of 256 and 3 of 512, with
// every stage after the first starting with a downsampling unit), and finally
// through avg_pool_everything and a 1000 output fully connected layer topped
// by loss_multiclass_log.
using anet_type = loss_multiclass_log<fc<1000,avg_pool_everything<
                            ares<512,ares<512,ares_down<512,
                            ares<256,ares<256,ares<256,ares<256,ares<256,ares_down<256,
                            ares<128,ares<128,ares<128,ares_down<128,
                            ares<64,ares<64,ares<64,
                            max_pool<3,3,2,2,relu<affine<con<64,7,7,2,2,
                            input_rgb_image_sized<227>
                            >>>>      // con, affine, relu, max_pool
                            >>>       // the 64 filter stage
                            >>>>      // the 128 filter stage
                            >>>>>>    // the 256 filter stage
                            >>>       // the 512 filter stage
                            >>>;      // avg_pool_everything, fc, loss_multiclass_log
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110

// ----------------------------------------------------------------------------------------

// Picks a random square region of img to crop.
//   img - the image the crop is taken from; only its dimensions are used here
//   rnd - random number generator; one double and two 32bit numbers are drawn
// The side of the square is a random fraction, between 0.466666666 and 0.875,
// of the shorter image dimension.  The returned rectangle is that square moved
// by a random offset smaller than the space left over in each dimension.
rectangle make_random_cropping_rect_resnet(
    const matrix<rgb_pixel>& img,
    dlib::rand& rnd
)
{
    // choose how large the crop is relative to the shorter side of the image
    const double min_scale = 0.466666666;
    const double max_scale = 0.875;
    const double scale = min_scale + rnd.get_random_double()*(max_scale-min_scale);
    const double side = scale*std::min(img.nr(), img.nc());
    const rectangle crop(side, side);

    // choose where the crop goes: any offset that leaves it inside the image
    const point shift(rnd.get_random_32bit_number()%(img.nc()-crop.width()),
                      rnd.get_random_32bit_number()%(img.nr()-crop.height()));
    return move_rect(crop, shift);
}

// ----------------------------------------------------------------------------------------

void randomly_crop_images (
    const matrix<rgb_pixel>& img,
    dlib::array<matrix<rgb_pixel>>& crops,
    dlib::rand& rnd,
    long num_crops
)
{
    std::vector<chip_details> dets;
    for (long i = 0; i < num_crops; ++i)
    {
        auto rect = make_random_cropping_rect_resnet(img, rnd);
        dets.push_back(chip_details(rect, chip_dims(227,227)));
    }

    extract_image_chips(img, dets, crops);

    for (auto&& img : crops)
    {
        // Also randomly flip the image
        if (rnd.get_random_double() > 0.5)
            img = fliplr(img);

        // And then randomly adjust the colors.
        apply_random_color_offset(img, rnd);
    }
}

// ----------------------------------------------------------------------------------------

int main(int argc, char** argv) try
{
111
112
113
    if (argc == 1)
    {
        cout << "Give this program image files as command line arguments.\n" << endl;
Davis King's avatar
Davis King committed
114
115
        cout << "You will also need a copy of the file resnet34_1000_imagenet_classifier.dnn " << endl;
        cout << "available at http://dlib.net/files/resnet34_1000_imagenet_classifier.dnn.bz2" << endl;
116
117
118
119
        cout << endl;
        return 1;
    }

120
121
122
123
124
125
126
127
128
129
130
131
132
133
    std::vector<string> labels;
    anet_type net;
    deserialize("resnet34_1000_imagenet_classifier.dnn") >> net >> labels;


    softmax<anet_type::subnet_type> snet; 
    snet.subnet() = net.subnet();

    dlib::array<matrix<rgb_pixel>> images;
    matrix<rgb_pixel> img, crop;

    dlib::rand rnd;
    image_window win;

134
    // read images from the command prompt and print the top 5 best labels for each.
135
136
137
138
    for (int i = 1; i < argc; ++i)
    {
        load_image(img, argv[i]);
        const int num_crops = 16;
139
140
        // Grab 16 random crops from the image.  We will run all of them through the
        // network and average the results.
141
        randomly_crop_images(img, images, rnd, num_crops);
142
        // p(i) == the probability the image contains object of class i.
143
144
145
146
147
148
149
150
151
152
        matrix<float,1,1000> p = sum_rows(mat(snet(images.begin(), images.end())))/num_crops;

        win.set_image(img);
        for (int k = 0; k < 5; ++k)
        {
            unsigned long predicted_label = index_of_max(p);
            cout << p(predicted_label) << ": " << labels[predicted_label] << endl;
            p(predicted_label) = 0;
        }

153
        cout << "Hit enter to process the next image";
154
155
156
157
158
159
160
161
162
        cin.get();
    }

}
catch(std::exception& e)
{
    cout << e.what() << endl;
}