// resnet.h
#ifndef ResNet_H
#define ResNet_H

#include <dlib/dnn.h>

// Compile-time definitions of ResNet network types assembled from dlib layers.
//
// The struct holds no data and no functions: every member is an alias (or
// alias template) naming a dlib network type.  As everywhere in this file,
// nested layer types read inside-out: the innermost template argument
// (SUBNET / INPUT) is the input side and the outermost layer is the output side.
//
// BATCHNORM must be bn_con or affine layer.  It is the normalization layer
// used by every block, and by the input stem, of the networks defined below.
// NOTE(review): in dlib, bn_con is typically used while training and affine
// for inference -- confirm against the dlib documentation.
template<template<typename> class BATCHNORM>
struct resnet
{
    // the resnet basic block, where BN is bn_con or affine
    // Data flow: 3x3 con (stride, stride) -> BN -> relu -> 3x3 con (1, 1) -> BN.
    // There is no trailing relu here; residual_block applies it after the
    // skip connection has been added.
    template<long num_filters, template<typename> class BN, int stride, typename SUBNET>
    using basicblock = BN<dlib::con<num_filters, 3, 3, 1, 1,
                  dlib::relu<BN<dlib::con<num_filters, 3, 3, stride, stride, SUBNET>>>>>;

    // the resnet bottleneck block
    // Data flow: 1x1 con (num_filters) -> BN -> relu
    //         -> 3x3 con (num_filters, stride, stride) -> BN -> relu
    //         -> 1x1 con (4 * num_filters) -> BN.
    // The block therefore outputs 4 * num_filters channels.  As with
    // basicblock, the final relu is added by residual_block.
    template<long num_filters, template<typename> class BN, int stride, typename SUBNET>
    using bottleneck = BN<dlib::con<4 * num_filters, 1, 1, 1, 1,
                  dlib::relu<BN<dlib::con<num_filters, 3, 3, stride, stride,
                  dlib::relu<BN<dlib::con<num_filters, 1, 1, 1, 1, SUBNET>>>>>>>>;

    // the resnet residual
    // SUBNET is marked with tag1, BLOCK is applied to it with stride 1, and
    // add_prev1 sums the block output with the tagged input.
    template<
        template<long, template<typename> class, int, typename> class BLOCK, // basicblock or bottleneck
        long num_filters,
        template<typename> class BN, // bn_con or affine
        typename SUBNET
    > // adds the block to the result of tag1 (the subnet)
    using residual = dlib::add_prev1<BLOCK<num_filters, BN, 1, dlib::tag1<SUBNET>>>;

    // a resnet residual that does subsampling on both paths
    // Reading inside-out:
    //   tag1  marks the input (SUBNET);
    //   BLOCK runs on it with stride 2 and its output is marked with tag2;
    //   skip1 goes back to the tag1 output (the untouched input);
    //   avg_pool<2, 2, 2, 2> subsamples that input so it matches the strided block;
    //   add_prev2 adds the tag2 (block) output to the pooled input.
    template<
        template<long, template<typename> class, int, typename> class BLOCK, // basicblock or bottleneck
        long num_filters,
        template<typename> class BN, // bn_con or affine
        typename SUBNET
    >
    using residual_down = dlib::add_prev2<dlib::avg_pool<2, 2, 2, 2,
                          dlib::skip1<dlib::tag2<BLOCK<num_filters, BN, 2,
                          dlib::tag1<SUBNET>>>>>>;

    // residual block with optional downsampling and custom regularization (bn_con or affine)
    // RESIDUAL is either residual (no subsampling) or residual_down; BLOCK is
    // basicblock or bottleneck.  The relu is applied to the sum produced by
    // RESIDUAL, i.e. after the skip connection.
    template<
        template<template<long, template<typename> class, int, typename> class, long, template<typename>class, typename> class RESIDUAL,
        template<long, template<typename> class, int, typename> class BLOCK,
        long num_filters,
        template<typename> class BN, // bn_con or affine
        typename SUBNET
    >
    using residual_block = dlib::relu<RESIDUAL<BLOCK, num_filters, BN, SUBNET>>;

    // Downsampling residual blocks, with the normalization layer fixed to the
    // struct-level BATCHNORM.  The backbones use one of these at the start of
    // each of the 128, 256 and 512 stages.
    template<long num_filters, typename SUBNET>
    using resbasicblock_down = residual_block<residual_down, basicblock, num_filters, BATCHNORM, SUBNET>;
    template<long num_filters, typename SUBNET>
    using resbottleneck_down = residual_block<residual_down, bottleneck, num_filters, BATCHNORM, SUBNET>;

    // some definitions to allow the use of the repeat layer
    // Each alias takes SUBNET as its only template parameter, with the filter
    // count and BATCHNORM already bound, so it can be passed to dlib::repeat
    // in the backbones below.
    template<typename SUBNET> using resbasicblock_512 = residual_block<residual, basicblock, 512, BATCHNORM, SUBNET>;
    template<typename SUBNET> using resbasicblock_256 = residual_block<residual, basicblock, 256, BATCHNORM, SUBNET>;
    template<typename SUBNET> using resbasicblock_128 = residual_block<residual, basicblock, 128, BATCHNORM, SUBNET>;
    template<typename SUBNET> using resbasicblock_64  = residual_block<residual, basicblock,  64, BATCHNORM, SUBNET>;
    template<typename SUBNET> using resbottleneck_512 = residual_block<residual, bottleneck, 512, BATCHNORM, SUBNET>;
    template<typename SUBNET> using resbottleneck_256 = residual_block<residual, bottleneck, 256, BATCHNORM, SUBNET>;
    template<typename SUBNET> using resbottleneck_128 = residual_block<residual, bottleneck, 128, BATCHNORM, SUBNET>;
    template<typename SUBNET> using resbottleneck_64  = residual_block<residual, bottleneck,  64, BATCHNORM, SUBNET>;

    // common processing for standard resnet inputs
    // Data flow: 7x7 con (64 filters, stride 2) -> BN -> relu -> max_pool<3, 3, 2, 2>.
    template<template<typename> class BN, typename INPUT>
    using input_processing = dlib::max_pool<3, 3, 2, 2, dlib::relu<BN<dlib::con<64, 7, 7, 2, 2, INPUT>>>>;

    // the resnet backbone with basicblocks
    // nb_N is the repeat count of the non-downsampling block with N filters.
    // Parameters are listed from the output side (512) to the input side (64),
    // matching the nesting order of the type.  The 128, 256 and 512 stages each
    // have one extra *_down block underneath their repeat, so those stages
    // contain nb_N + 1 residual blocks; the 64 stage sits directly on
    // input_processing and contains exactly nb_64.
    template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
    using backbone_basicblock =
        dlib::repeat<nb_512, resbasicblock_512, resbasicblock_down<512,
        dlib::repeat<nb_256, resbasicblock_256, resbasicblock_down<256,
        dlib::repeat<nb_128, resbasicblock_128, resbasicblock_down<128,
        dlib::repeat<nb_64,  resbasicblock_64, input_processing<BATCHNORM, INPUT>>>>>>>>;

    // the resnet backbone with bottlenecks
    // Same layout and parameter meaning as backbone_basicblock, but built from
    // bottleneck blocks.
    template<long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
    using backbone_bottleneck =
        dlib::repeat<nb_512, resbottleneck_512, resbottleneck_down<512,
        dlib::repeat<nb_256, resbottleneck_256, resbottleneck_down<256,
        dlib::repeat<nb_128, resbottleneck_128, resbottleneck_down<128,
        dlib::repeat<nb_64,  resbottleneck_64, input_processing<BATCHNORM, INPUT>>>>>>>>;

    // the backbones for the classic architectures
    // Arguments are <nb_512, nb_256, nb_128, nb_64>.  Counting the extra *_down
    // block in the 128/256/512 stages, the residual blocks per stage
    // (64, 128, 256, 512) are:
    //   backbone_18 : 2, 2,  2, 2      backbone_34 : 3, 4,  6, 3
    //   backbone_50 : 3, 4,  6, 3      backbone_101: 3, 4, 23, 3
    //   backbone_152: 3, 8, 36, 3
    template<typename INPUT> using backbone_18  = backbone_basicblock<1, 1, 1, 2, INPUT>;
    template<typename INPUT> using backbone_34  = backbone_basicblock<2, 5, 3, 3, INPUT>;
    template<typename INPUT> using backbone_50  = backbone_bottleneck<2, 5, 3, 3, INPUT>;
    template<typename INPUT> using backbone_101 = backbone_bottleneck<2, 22, 3, 3, INPUT>;
    template<typename INPUT> using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>;

    // the typical classifier models
    // Each is a backbone on dlib::input_rgb_image, followed by
    // avg_pool_everything, a 1000-output fc layer and loss_multiclass_log.
    using n18  = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_18<dlib::input_rgb_image>>>>;
    using n34  = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_34<dlib::input_rgb_image>>>>;
    using n50  = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_50<dlib::input_rgb_image>>>>;
    using n101 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_101<dlib::input_rgb_image>>>>;
    using n152 = dlib::loss_multiclass_log<dlib::fc<1000, dlib::avg_pool_everything<backbone_152<dlib::input_rgb_image>>>>;
};

#endif // ResNet_H