// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_LAYERS_H_
#define DLIB_DNn_LAYERS_H_

#include "layers_abstract.h"
#include "tensor.h"
#include "core.h"
#include <iostream>
#include <string>
#include "../rand.h"
#include "../string.h"
#include "tensor_tools.h"
#include "../vectorstream.h"
#include "utilities.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        long _num_filters,
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x,
        int _padding_y = _stride_y!=1? 0 : _nr/2,
        int _padding_x = _stride_x!=1? 0 : _nc/2
        >
    class con_
    {
    public:

        static_assert(_num_filters > 0, "The number of filters must be > 0");
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
        static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size.");
        static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");

        con_(
        ) :
            learning_rate_multiplier(1),
            weight_decay_multiplier(1),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(0),
            padding_y_(_padding_y),
            padding_x_(_padding_x)
        {}

        long num_filters() const { return _num_filters; }
        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }
        long padding_y() const { return padding_y_; }
        long padding_x() const { return padding_x_; }

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }


        con_ (
            const con_& item
        ) :
            params(item.params),
            filters(item.filters),
            biases(item.biases),
            learning_rate_multiplier(item.learning_rate_multiplier),
            weight_decay_multiplier(item.weight_decay_multiplier),
            bias_learning_rate_multiplier(item.bias_learning_rate_multiplier),
            bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
            padding_y_(item.padding_y_),
            padding_x_(item.padding_x_)
        {
            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
        }

        con_& operator= (
            const con_& item
        )
        {
            if (this == &item)
                return *this;

            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
            params = item.params;
            filters = item.filters;
            biases = item.biases;
            padding_y_ = item.padding_y_;
            padding_x_ = item.padding_x_;
            learning_rate_multiplier = item.learning_rate_multiplier;
            weight_decay_multiplier = item.weight_decay_multiplier;
            bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
            bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            long num_inputs = _nr*_nc*sub.get_output().k();
            long num_outputs = _num_filters;
            // allocate params for the filters and also for the filter bias values.
            params.set_size(num_inputs*_num_filters + _num_filters);

            dlib::rand rnd(std::rand());
            randomize_parameters(params, num_inputs+num_outputs, rnd);

            filters = alias_tensor(_num_filters, sub.get_output().k(), _nr, _nc);
            biases = alias_tensor(1,_num_filters);

            // set the initial bias values to zero
            biases(params,filters.size()) = 0;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            conv(output,
                sub.get_output(),
                filters(params,0),
                _stride_y,
                _stride_x,
                padding_y_,
                padding_x_
                );

            tt::add(1,output,1,biases(params,filters.size()));
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            conv.get_gradient_for_data (gradient_input, filters(params,0), sub.get_gradient_input());
            // no point computing the parameter gradients if they won't be used.
            if (learning_rate_multiplier != 0)
            {
                auto filt = filters(params_grad,0);
                conv.get_gradient_for_filters (gradient_input, sub.get_output(), filt);
                auto b = biases(params_grad, filters.size());
                tt::assign_conv_bias_gradient(b, gradient_input);
            }
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const con_& item, std::ostream& out)
        {
            serialize("con_4", out);
            serialize(item.params, out);
            serialize(_num_filters, out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.padding_y_, out);
            serialize(item.padding_x_, out);
            serialize(item.filters, out);
            serialize(item.biases, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
        }

        friend void deserialize(con_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            long num_filters;
            long nr;
            long nc;
            int stride_y;
            int stride_x;
            if (version == "con_4")
            {
                deserialize(item.params, in);
                deserialize(num_filters, in);
                deserialize(nr, in);
                deserialize(nc, in);
                deserialize(stride_y, in);
                deserialize(stride_x, in);
                deserialize(item.padding_y_, in);
                deserialize(item.padding_x_, in);
                deserialize(item.filters, in);
                deserialize(item.biases, in);
                deserialize(item.learning_rate_multiplier, in);
                deserialize(item.weight_decay_multiplier, in);
                deserialize(item.bias_learning_rate_multiplier, in);
                deserialize(item.bias_weight_decay_multiplier, in);
                if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
                if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
                if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
                if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
                if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
                if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
                if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
            }
        }

        friend std::ostream& operator<<(std::ostream& out, const con_& item)
        {
            out << "con\t ("
                << "num_filters="<<_num_filters
                << ", nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ", padding_y="<<item.padding_y_
                << ", padding_x="<<item.padding_x_
                << ")";
            out << " learning_rate_mult="<<item.learning_rate_multiplier;
            out << " weight_decay_mult="<<item.weight_decay_multiplier;
            out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
            out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
            return out;
        }

        friend void to_xml(const con_& item, std::ostream& out)
        {
            out << "<con"
                << " num_filters='"<<_num_filters<<"'"
                << " nr='"<<_nr<<"'"
                << " nc='"<<_nc<<"'"
                << " stride_y='"<<_stride_y<<"'"
                << " stride_x='"<<_stride_x<<"'"
                << " padding_y='"<<item.padding_y_<<"'"
                << " padding_x='"<<item.padding_x_<<"'"
                << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n";
            out << mat(item.params);
            out << "</con>";
        }
    private:

        resizable_tensor params;
        alias_tensor filters, biases;

        tt::tensor_conv conv;

        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;

        // These are here only because older versions of con (which you might encounter
        // serialized to disk) used different padding settings.
        int padding_y_;
        int padding_x_;
    };

    template <
        long num_filters,
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
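    // A rough usage sketch (illustrative only, not part of this header's interface):
    // layer templates compose by nesting, with the input layer innermost.  Assuming the
    // relu layer defined later in this file and dlib's usual input and loss layers, a
    // small classifier could be declared as:
    //
    //     using net_type = loss_multiclass_log<
    //                         fc<10,
    //                         relu<con<16,5,5,2,2,
    //                         input<matrix<unsigned char>>>>>>;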

// ----------------------------------------------------------------------------------------

    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x,
        int _padding_y = _stride_y!=1? 0 : _nr/2,
        int _padding_x = _stride_x!=1? 0 : _nc/2
        >
    class max_pool_
    {
        static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
        static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
        static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");
        static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");
    public:


        max_pool_(
        ) :
            padding_y_(_padding_y),
            padding_x_(_padding_x)
        {}

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }
        long padding_y() const { return padding_y_; }
        long padding_x() const { return padding_x_; }

        max_pool_ (
            const max_pool_& item
        )  :
            padding_y_(item.padding_y_),
            padding_x_(item.padding_x_)
        {
            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
        }

        max_pool_& operator= (
            const max_pool_& item
        )
        {
            if (this == &item)
                return *this;

            padding_y_ = item.padding_y_;
            padding_x_ = item.padding_x_;

            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(),
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            mp(output, sub.get_output());
        } 

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(),
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const max_pool_& item, std::ostream& out)
        {
            serialize("max_pool_2", out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.padding_y_, out);
            serialize(item.padding_x_, out);
        }

        friend void deserialize(max_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            long nr;
            long nc;
            int stride_y;
            int stride_x;
            if (version == "max_pool_2")
            {
                deserialize(nr, in);
                deserialize(nc, in);
                deserialize(stride_y, in);
                deserialize(stride_x, in);
                deserialize(item.padding_y_, in);
                deserialize(item.padding_x_, in);
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
            }

            if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_");
            if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_");
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
        }

        friend std::ostream& operator<<(std::ostream& out, const max_pool_& item)
        {
            out << "max_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ", padding_y="<<item.padding_y_
                << ", padding_x="<<item.padding_x_
                << ")";
            return out;
        }

        friend void to_xml(const max_pool_& item, std::ostream& out)
        {
            out << "<max_pool"
                << " nr='"<<_nr<<"'"
                << " nc='"<<_nc<<"'"
                << " stride_y='"<<_stride_y<<"'"
                << " stride_x='"<<_stride_x<<"'"
                << " padding_y='"<<item.padding_y_<<"'"
                << " padding_x='"<<item.padding_x_<<"'"
                << "/>\n";
        }

    private:

        tt::pooling mp;
        resizable_tensor params;

        int padding_y_;
        int padding_x_;
    };

    template <
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;

    template <
        typename SUBNET
        >
    using max_pool_everything = add_layer<max_pool_<0,0,1,1>, SUBNET>;
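    // A rough usage sketch (illustrative only): max_pool<3,3,2,2,SUBNET> slides a 3x3
    // window over each channel with stride 2 and keeps the maximum value, while
    // max_pool_everything<SUBNET> pools each channel over its entire spatial extent,
    // yielding one value per channel (global max pooling), e.g.:
    //
    //     fc<10, max_pool_everything<relu<con<32,3,3,1,1, input<matrix<float>>>>>>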

// ----------------------------------------------------------------------------------------

    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x,
        int _padding_y = _stride_y!=1? 0 : _nr/2,
        int _padding_x = _stride_x!=1? 0 : _nc/2
        >
    class avg_pool_
    {
    public:
        static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
        static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
        static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");
        static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");

        avg_pool_(
        ) :
            padding_y_(_padding_y),
            padding_x_(_padding_x)
        {}

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }
        long padding_y() const { return padding_y_; }
        long padding_x() const { return padding_x_; }

        avg_pool_ (
            const avg_pool_& item
        )  :
            padding_y_(item.padding_y_),
            padding_x_(item.padding_x_)
        {
            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
        }

        avg_pool_& operator= (
            const avg_pool_& item
        )
        {
            if (this == &item)
                return *this;

            padding_y_ = item.padding_y_;
            padding_x_ = item.padding_x_;

            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(), 
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            ap(output, sub.get_output());
        } 

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(), 
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const avg_pool_& item, std::ostream& out)
        {
            serialize("avg_pool_2", out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.padding_y_, out);
            serialize(item.padding_x_, out);
        }

        friend void deserialize(avg_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);

            long nr;
            long nc;
            int stride_y;
            int stride_x;
            if (version == "avg_pool_2")
            {
                deserialize(nr, in);
                deserialize(nc, in);
                deserialize(stride_y, in);
                deserialize(stride_x, in);
                deserialize(item.padding_y_, in);
                deserialize(item.padding_x_, in);
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
            }

            if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_");
            if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_");
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
        }

        friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item)
        {
            out << "avg_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ", padding_y="<<item.padding_y_
                << ", padding_x="<<item.padding_x_
                << ")";
            return out;
        }

        friend void to_xml(const avg_pool_& item, std::ostream& out)
        {
            out << "<avg_pool"
                << " nr='"<<_nr<<"'"
                << " nc='"<<_nc<<"'"
                << " stride_y='"<<_stride_y<<"'"
                << " stride_x='"<<_stride_x<<"'"
                << " padding_y='"<<item.padding_y_<<"'"
                << " padding_x='"<<item.padding_x_<<"'"
                << "/>\n";
        }
    private:

        tt::pooling ap;
        resizable_tensor params;

        int padding_y_;
        int padding_x_;
    };

    template <
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;

    template <
        typename SUBNET
        >
    using avg_pool_everything = add_layer<avg_pool_<0,0,1,1>, SUBNET>;
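    // A rough usage sketch (illustrative only): avg_pool is used exactly like max_pool
    // but averages the window contents, and avg_pool_everything<SUBNET> gives global
    // average pooling, a common way to collapse the spatial dimensions before a final
    // fc layer:
    //
    //     fc<1000, avg_pool_everything<SUBNET>>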

// ----------------------------------------------------------------------------------------

    enum layer_mode
    {
        CONV_MODE = 0,
        FC_MODE = 1
    };

    const double DEFAULT_BATCH_NORM_EPS = 0.00001;

    template <
        layer_mode mode
        >
    class bn_
    {
    public:
        explicit bn_(
            unsigned long window_size,
            double eps_ = DEFAULT_BATCH_NORM_EPS
        ) : 
            num_updates(0), 
            running_stats_window_size(window_size),
            learning_rate_multiplier(1),
            weight_decay_multiplier(0),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(1),
            eps(eps_)
        {}

        bn_() : bn_(1000) {}

        layer_mode get_mode() const { return mode; }
        unsigned long get_running_stats_window_size () const { return running_stats_window_size; }
        double get_eps() const { return eps; }

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            if (mode == FC_MODE)
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;

            running_means.copy_size(gamma(params,0));
            running_variances.copy_size(gamma(params,0));
            running_means = 0;
            running_variances = 1;
            num_updates = 0;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            if (sub.get_output().num_samples() > 1)
            {
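                // Training mode.  The decay below equals 1/(num_updates+1), so (roughly
                // speaking) the running statistics form an exact average over the first
                // running_stats_window_size batches and an exponential moving average
                // after that.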
                const double decay = 1.0 - num_updates/(num_updates+1.0);
                if (num_updates < running_stats_window_size)
                    ++num_updates;
                if (mode == FC_MODE)
                    tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
                else
                    tt::batch_normalize_conv(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
            }
            else // we are running in testing mode so we just linearly scale the input tensor.
            {
                if (mode == FC_MODE)
                    tt::batch_normalize_inference(eps, output, sub.get_output(), g, b, running_means, running_variances);
                else
                    tt::batch_normalize_conv_inference(eps, output, sub.get_output(), g, b, running_means, running_variances);
            }
        } 

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            auto g = gamma(params,0);
            auto g_grad = gamma(params_grad, 0);
            auto b_grad = beta(params_grad, gamma.size());
            if (mode == FC_MODE)
                tt::batch_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
            else
                tt::batch_normalize_conv_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const bn_& item, std::ostream& out)
        {
            if (mode == CONV_MODE)
                serialize("bn_con2", out);
            else // if FC_MODE
                serialize("bn_fc2", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize(item.means, out);
            serialize(item.invstds, out);
            serialize(item.running_means, out);
            serialize(item.running_variances, out);
            serialize(item.num_updates, out);
            serialize(item.running_stats_window_size, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
            serialize(item.eps, out);
        }

        friend void deserialize(bn_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (mode == CONV_MODE)
            {
                if (version != "bn_con2")
                    throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
            }
            else // must be in FC_MODE
            {
                if (version != "bn_fc2")
                    throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
            }

            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            deserialize(item.means, in);
            deserialize(item.invstds, in);
            deserialize(item.running_means, in);
            deserialize(item.running_variances, in);
            deserialize(item.num_updates, in);
            deserialize(item.running_stats_window_size, in);
            deserialize(item.learning_rate_multiplier, in);
            deserialize(item.weight_decay_multiplier, in);
            deserialize(item.bias_learning_rate_multiplier, in);
            deserialize(item.bias_weight_decay_multiplier, in);
            deserialize(item.eps, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const bn_& item)
        {
            if (mode == CONV_MODE)
                out << "bn_con  ";
            else
                out << "bn_fc   ";
            out << " eps="<<item.eps;
            out << " learning_rate_mult="<<item.learning_rate_multiplier;
            out << " weight_decay_mult="<<item.weight_decay_multiplier;
            out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
            out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
            return out;
        }

        friend void to_xml(const bn_& item, std::ostream& out)
        {
            if (mode==CONV_MODE)
                out << "<bn_con";
            else
                out << "<bn_fc";

            out << " eps='"<<item.eps<<"'";
            out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'";
            out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
            out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'";
            out << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
            out << ">\n";

            out << mat(item.params);

            if (mode==CONV_MODE)
                out << "</bn_con>\n";
            else
                out << "</bn_fc>\n";
        }

    private:

        friend class affine_;

        resizable_tensor params;
        alias_tensor gamma, beta;
        resizable_tensor means, running_means;
        resizable_tensor invstds, running_variances;
        unsigned long num_updates;
        unsigned long running_stats_window_size;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
        double eps;
    };

    template <typename SUBNET>
    using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
    template <typename SUBNET>
    using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
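    // A rough usage sketch (illustrative only): bn_con normalizes per convolutional
    // channel and bn_fc per fully connected output, so a typical block looks like
    // relu<bn_con<con<32,3,3,1,1,SUBNET>>>.  Note from forward() above that batch
    // statistics are used only when num_samples() > 1; single-sample (test time) calls
    // use the stored running means and variances instead.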

// ----------------------------------------------------------------------------------------

    enum fc_bias_mode
    {
        FC_HAS_BIAS = 0,
        FC_NO_BIAS = 1
    };

    struct num_fc_outputs
    {
        num_fc_outputs(unsigned long n) : num_outputs(n) {}
        unsigned long num_outputs;
    };

    template <
        unsigned long num_outputs_,
        fc_bias_mode bias_mode
        >
    class fc_
    {
        static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");

    public:
        fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0),
            learning_rate_multiplier(1),
            weight_decay_multiplier(1),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(0)
        {}

        fc_() : fc_(num_fc_outputs(num_outputs_)) {}

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }

        unsigned long get_num_outputs (
        ) const { return num_outputs; }

        fc_bias_mode get_bias_mode (
        ) const { return bias_mode; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
            if (bias_mode == FC_HAS_BIAS)
                params.set_size(num_inputs+1, num_outputs);
            else
                params.set_size(num_inputs, num_outputs);

            dlib::rand rnd(std::rand());
            randomize_parameters(params, num_inputs+num_outputs, rnd);

            weights = alias_tensor(num_inputs, num_outputs);

            if (bias_mode == FC_HAS_BIAS)
            {
                biases = alias_tensor(1,num_outputs);
                // set the initial bias values to zero
                biases(params,weights.size()) = 0;
            }
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            output.set_size(sub.get_output().num_samples(), num_outputs);

            auto w = weights(params, 0);
            tt::gemm(0,output, 1,sub.get_output(),false, w,false);
            if (bias_mode == FC_HAS_BIAS)
            {
                auto b = biases(params, weights.size());
                tt::add(1,output,1,b);
            }
        } 

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            // no point computing the parameter gradients if they won't be used.
            if (learning_rate_multiplier != 0)
            {
                // compute the gradient of the weight parameters.  
                auto pw = weights(params_grad, 0);
                tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);

                if (bias_mode == FC_HAS_BIAS)
                {
                    // compute the gradient of the bias parameters.  
                    auto pb = biases(params_grad, weights.size());
                    tt::assign_bias_gradient(pb, gradient_input);
                }
            }

            // compute the gradient for the data
            auto w = weights(params, 0);
            tt::gemm(1,sub.get_gradient_input(), 1,gradient_input,false, w,true);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const fc_& item, std::ostream& out)
        {
            serialize("fc_2", out);
            serialize(item.num_outputs, out);
            serialize(item.num_inputs, out);
            serialize(item.params, out);
            serialize(item.weights, out);
            serialize(item.biases, out);
            serialize((int)bias_mode, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
        }

        friend void deserialize(fc_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "fc_2")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");

            deserialize(item.num_outputs, in);
            deserialize(item.num_inputs, in);
            deserialize(item.params, in);
            deserialize(item.weights, in);
            deserialize(item.biases, in);
            int bmode = 0;
            deserialize(bmode, in);
            if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
            deserialize(item.learning_rate_multiplier, in);
            deserialize(item.weight_decay_multiplier, in);
            deserialize(item.bias_learning_rate_multiplier, in);
            deserialize(item.bias_weight_decay_multiplier, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const fc_& item)
        {
            if (bias_mode == FC_HAS_BIAS)
            {
                out << "fc\t ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
                out << " learning_rate_mult="<<item.learning_rate_multiplier;
                out << " weight_decay_mult="<<item.weight_decay_multiplier;
                out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
                out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
            }
            else
            {
                out << "fc_no_bias ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
                out << " learning_rate_mult="<<item.learning_rate_multiplier;
                out << " weight_decay_mult="<<item.weight_decay_multiplier;
            }
            return out;
        }

        friend void to_xml(const fc_& item, std::ostream& out)
        {
            if (bias_mode==FC_HAS_BIAS)
            {
                out << "<fc"
                    << " num_outputs='"<<item.num_outputs<<"'"
                    << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                    << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                    << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
                    << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
                out << ">\n";
                out << mat(item.params);
                out << "</fc>\n";
            }
            else
            {
                out << "<fc_no_bias"
                    << " num_outputs='"<<item.num_outputs<<"'"
                    << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                    << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
                out << ">\n";
                out << mat(item.params);
                out << "</fc_no_bias>\n";
            }
        }

    private:

        unsigned long num_outputs;
        unsigned long num_inputs;
        resizable_tensor params;
        alias_tensor weights, biases;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
    };

    template <
        unsigned long num_outputs,
        typename SUBNET
        >
    using fc = add_layer<fc_<num_outputs,FC_HAS_BIAS>, SUBNET>;

    template <
        unsigned long num_outputs,
        typename SUBNET
        >
    using fc_no_bias = add_layer<fc_<num_outputs,FC_NO_BIAS>, SUBNET>;
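    // A rough usage sketch (illustrative only): fc flattens its input tensor, so
    // fc<128, avg_pool_everything<SUBNET>> maps each sample to a 128 dimensional row
    // vector.  The output dimension can also be chosen at runtime through the
    // num_fc_outputs constructor argument, e.g.:
    //
    //     fc_<10,FC_HAS_BIAS> layer(num_fc_outputs(42));  // 42 outputs despite the 10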

// ----------------------------------------------------------------------------------------

    class dropout_
    {
    public:
        explicit dropout_(
            float drop_rate_ = 0.5
        ) :
            drop_rate(drop_rate_),
            rnd(std::rand())
        {
            DLIB_CASSERT(0 <= drop_rate && drop_rate <= 1);
        }

        // We have to add a copy constructor and assignment operator because the rnd object
        // is non-copyable.
        dropout_(
            const dropout_& item
        ) : drop_rate(item.drop_rate), mask(item.mask), rnd(std::rand())
        {}

        dropout_& operator= (
            const dropout_& item
        )
        {
            if (this == &item)
                return *this;

            drop_rate = item.drop_rate;
            mask = item.mask;
            return *this;
        }

        float get_drop_rate (
        ) const { return drop_rate; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            // create a random mask and use it to filter the data
            mask.copy_size(input);
            rnd.fill_uniform(mask);
            tt::threshold(mask, drop_rate);
            tt::multiply(false, output, input, mask);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            if (is_same_object(gradient_input, data_grad))
                tt::multiply(false, data_grad, mask, gradient_input);
            else
                tt::multiply(true, data_grad, mask, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const dropout_& item, std::ostream& out)
        {
            serialize("dropout_", out);
            serialize(item.drop_rate, out);
            serialize(item.mask, out);
        }

        friend void deserialize(dropout_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "dropout_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_.");
            deserialize(item.drop_rate, in);
            deserialize(item.mask, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const dropout_& item)
        {
            out << "dropout\t ("
                << "drop_rate="<<item.drop_rate
                << ")";
            return out;
        }

        friend void to_xml(const dropout_& item, std::ostream& out)
        {
            out << "<dropout"
                << " drop_rate='"<<item.drop_rate<<"'";
            out << "/>\n";
        }

    private:
        float drop_rate;
        resizable_tensor mask;

        tt::tensor_rand rnd;
        resizable_tensor params; // unused
    };


    template <typename SUBNET>
    using dropout = add_layer<dropout_, SUBNET>;
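    // A rough usage sketch (illustrative only): dropout zeroes a random drop_rate
    // fraction of its inputs on each forward pass, e.g.
    // fc<10, dropout<relu<fc<84, SUBNET>>>>.  For inference, dropout layers are
    // typically swapped for the multiply layer defined below, which applies the
    // equivalent deterministic scaling.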

// ----------------------------------------------------------------------------------------

    class multiply_
    {
    public:
        explicit multiply_(
            float val_ = 0.5
        ) :
            val(val_)
        {
        }

        multiply_ (
            const dropout_& item
        ) : val(1-item.get_drop_rate()) {}

        float get_multiply_value (
        ) const { return val; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::affine_transform(output, input, val);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            if (is_same_object(gradient_input, data_grad))
                tt::affine_transform(data_grad, gradient_input, val);
            else
                tt::affine_transform(data_grad, data_grad, gradient_input, 1, val);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const multiply_& item, std::ostream& out)
        {
            serialize("multiply_", out);
            serialize(item.val, out);
        }

        friend void deserialize(multiply_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "dropout_")
            {
                // Since we can build a multiply_ from a dropout_ we check if that's what
                // is in the stream and if so then just convert it right here.
                unserialize sin(version, in);
                dropout_ temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }

            if (version != "multiply_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_.");
            deserialize(item.val, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const multiply_& item)
        {
            out << "multiply ("
                << "val="<<item.val
                << ")";
            return out;
        }

        friend void to_xml(const multiply_& item, std::ostream& out)
        {
            out << "<multiply"
                << " val='"<<item.val<<"'";
            out << "/>\n";
        }
    private:
        float val;
        resizable_tensor params; // unused
    };

    template <typename SUBNET>
    using multiply = add_layer<multiply_, SUBNET>;
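    // Note that multiply_ is constructible from dropout_ (with val = 1-drop_rate) and
    // its deserialize() accepts serialized dropout_ layers, so a network trained with
    // dropout can be loaded into an otherwise identical network type that uses multiply
    // instead, replacing the stochastic layer with its deterministic test-time
    // equivalent.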

// ----------------------------------------------------------------------------------------

    class affine_
    {
    public:
        affine_(
        ) : mode(FC_MODE)
        {
        }

        affine_(
            layer_mode mode_
        ) : mode(mode_)
        {
        }

        template <
            layer_mode bnmode
            >
        affine_(
            const bn_<bnmode>& item
        )
        {
            gamma = item.gamma;
            beta = item.beta;
            mode = bnmode;

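            // Fold the batch normalization statistics into a fixed affine transform:
            // the layer then computes g*x + b where g = gamma/sqrt(running_variances+eps)
            // and b = beta - g*running_means, which matches bn_'s inference-mode output.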
            params.copy_size(item.params);

            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            
            resizable_tensor temp(item.params);
            auto sg = gamma(temp,0);
            auto sb = beta(temp,gamma.size());

            g = pointwise_multiply(mat(sg), 1.0f/sqrt(mat(item.running_variances)+item.get_eps()));
            b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means));
        }

        layer_mode get_mode() const { return mode; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            if (mode == FC_MODE)
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            if (mode == FC_MODE)
                tt::affine_transform(output, input, g, b);
            else
                tt::affine_transform_conv(output, input, g, b);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());

            // We are computing the gradient of dot(gradient_input, computed_output*g + b)
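            // Since output = input*g + b, d(output)/d(input) = g and the data gradient
            // is just g*gradient_input.  No parameter gradient is computed here because
            // this layer exposes empty_params, i.e. its parameters are not learnable.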
            if (mode == FC_MODE)
            {
                if (is_same_object(gradient_input, data_grad))
                    tt::multiply(false, data_grad, gradient_input, g);
                else
                    tt::multiply(true, data_grad, gradient_input, g);
            }
            else
            {
                if (is_same_object(gradient_input, data_grad))
                    tt::multiply_conv(false, data_grad, gradient_input, g);
                else
                    tt::multiply_conv(true, data_grad, gradient_input, g);
            }
        }

        const tensor& get_layer_params() const { return empty_params; }
        tensor& get_layer_params() { return empty_params; }

        friend void serialize(const affine_& item, std::ostream& out)
        {
            serialize("affine_", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize((int)item.mode, out);
        }

        friend void deserialize(affine_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "bn_con2")
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
                bn_<CONV_MODE> temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }
            else if (version == "bn_fc2")
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
                bn_<FC_MODE> temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }

            if (version != "affine_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");
            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            int mode;
            deserialize(mode, in);
            item.mode = (layer_mode)mode;
        }

        friend std::ostream& operator<<(std::ostream& out, const affine_& )
        {
            out << "affine";
            return out;
        }

        friend void to_xml(const affine_& item, std::ostream& out)
        {
            out << "<affine";
            if (item.mode==CONV_MODE)
                out << " mode='conv'";
            else
                out << " mode='fc'";
            out << ">\n";
            out << mat(item.params);
            out << "</affine>\n";
        }

    private:
        resizable_tensor params, empty_params; 
        alias_tensor gamma, beta;
        layer_mode mode;
    };

    template <typename SUBNET>
    using affine = add_layer<affine_, SUBNET>;
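
    // Illustrative sketch (hypothetical network shapes): because deserialize() above
    // converts serialized bn_ layers into affine_, a network trained with batch
    // normalization can be declared with affine for faster inference, e.g.
    //
    //     using train_net = loss_multiclass_log<fc<10, relu<bn_con<con<16,5,5,1,1, input<matrix<float>>>>>>>;
    //     using test_net  = loss_multiclass_log<fc<10, relu<affine<con<16,5,5,1,1, input<matrix<float>>>>>>>;
    //
    //     test_net net;
    //     deserialize("trained.dat") >> net;  // bn_con layers in the stream load as affine_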

// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class add_prev_
    {
    public:
        const static unsigned long id = tag_id<tag>::id;

        add_prev_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto&& t1 = sub.get_output();
            auto&& t2 = layer<tag>(sub).get_output();
            output.set_size(std::max(t1.num_samples(),t2.num_samples()),
                            std::max(t1.k(),t2.k()),
                            std::max(t1.nr(),t2.nr()),
                            std::max(t1.nc(),t2.nc()));
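            // Note: tt::add() supports broadcasting, so t1 and t2 need not have the same
            // shape; each dimension must either match the other tensor's or be 1, which
            // is why the output is sized to the maximum along every dimension.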
            tt::add(output, t1, t2);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            // The gradient just flows backwards to the two layers that forward() added
            // together.
            tt::add(sub.get_gradient_input(), sub.get_gradient_input(), gradient_input);
            tt::add(layer<tag>(sub).get_gradient_input(), layer<tag>(sub).get_gradient_input(), gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const add_prev_& , std::ostream& out)
        {
            serialize("add_prev_", out);
        }

        friend void deserialize(add_prev_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "add_prev_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const add_prev_& item)
        {
            out << "add_prev"<<id;
            return out;
        }

        friend void to_xml(const add_prev_& item, std::ostream& out)
        {
            out << "<add_prev tag='"<<id<<"'/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <
        template<typename> class tag,
        typename SUBNET
        >
    using add_prev = add_layer<add_prev_<tag>, SUBNET>;

    template <typename SUBNET> using add_prev1  = add_prev<tag1, SUBNET>;
    template <typename SUBNET> using add_prev2  = add_prev<tag2, SUBNET>;
    template <typename SUBNET> using add_prev3  = add_prev<tag3, SUBNET>;
    template <typename SUBNET> using add_prev4  = add_prev<tag4, SUBNET>;
    template <typename SUBNET> using add_prev5  = add_prev<tag5, SUBNET>;
    template <typename SUBNET> using add_prev6  = add_prev<tag6, SUBNET>;
    template <typename SUBNET> using add_prev7  = add_prev<tag7, SUBNET>;
    template <typename SUBNET> using add_prev8  = add_prev<tag8, SUBNET>;
    template <typename SUBNET> using add_prev9  = add_prev<tag9, SUBNET>;
    template <typename SUBNET> using add_prev10 = add_prev<tag10, SUBNET>;

    using add_prev1_  = add_prev_<tag1>;
    using add_prev2_  = add_prev_<tag2>;
    using add_prev3_  = add_prev_<tag3>;
    using add_prev4_  = add_prev_<tag4>;
    using add_prev5_  = add_prev_<tag5>;
    using add_prev6_  = add_prev_<tag6>;
    using add_prev7_  = add_prev_<tag7>;
    using add_prev8_  = add_prev_<tag8>;
    using add_prev9_  = add_prev_<tag9>;
    using add_prev10_ = add_prev_<tag10>;
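
    // Illustrative sketch (hypothetical block, following the tagged-skip pattern used
    // in dlib's ResNet examples): tag1 marks the block's input and add_prev1 adds it
    // back to the output of the intervening layers, forming a residual connection:
    //
    //     template <typename SUBNET>
    //     using res_block = relu<add_prev1<bn_con<con<8,3,3,1,1,
    //                            relu<bn_con<con<8,3,3,1,1, tag1<SUBNET>>>>>>>>;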

// ----------------------------------------------------------------------------------------

    class relu_
    {
    public:
        relu_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::relu(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::relu_gradient(data_grad, computed_output, gradient_input);
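            // Note that the gradient is computable from computed_output alone: relu's
            // output is positive exactly where its input was, so the input needn't be saved.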
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const relu_& , std::ostream& out)
        {
            serialize("relu_", out);
        }

        friend void deserialize(relu_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "relu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const relu_& )
        {
            out << "relu";
            return out;
        }

        friend void to_xml(const relu_& /*item*/, std::ostream& out)
        {
            out << "<relu/>\n";
        }

    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using relu = add_layer<relu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class prelu_
    {
    public:
        explicit prelu_(
            float initial_param_value_ = 0.25
        ) : initial_param_value(initial_param_value_)
        {
        }

        float get_initial_param_value (
        ) const { return initial_param_value; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
            params.set_size(1);
            params = initial_param_value;
        }

        template <typename SUBNET>
        void forward(
            const SUBNET& sub, 
            resizable_tensor& data_output
        )
        {
            data_output.copy_size(sub.get_output());
            tt::prelu(data_output, sub.get_output(), params);
        }

        template <typename SUBNET>
        void backward(
            const tensor& gradient_input, 
            SUBNET& sub, 
            tensor& params_grad
        )
        {
            tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(), 
                gradient_input, params, params_grad);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const prelu_& item, std::ostream& out)
        {
            serialize("prelu_", out);
            serialize(item.params, out);
            serialize(item.initial_param_value, out);
        }

        friend void deserialize(prelu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "prelu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_.");
            deserialize(item.params, in);
            deserialize(item.initial_param_value, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const prelu_& item)
        {
            out << "prelu\t ("
                << "initial_param_value="<<item.initial_param_value
                << ")";
            return out;
        }

        friend void to_xml(const prelu_& item, std::ostream& out)
        {
            out << "<prelu initial_param_value='"<<item.initial_param_value<<"'>\n";
            out << mat(item.params);
            out << "</prelu>\n";
        }

    private:
        resizable_tensor params;
        float initial_param_value;
    };

    template <typename SUBNET>
    using prelu = add_layer<prelu_, SUBNET>;
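
    // prelu computes f(x) = x for x > 0 and f(x) = p*x otherwise, where p is a single
    // learned scalar initialized to initial_param_value (0.25 by default, following
    // He et al., "Delving Deep into Rectifiers").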

// ----------------------------------------------------------------------------------------

    class sig_
    {
    public:
        sig_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::sigmoid(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::sigmoid_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const sig_& , std::ostream& out)
        {
            serialize("sig_", out);
        }

        friend void deserialize(sig_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "sig_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const sig_& )
        {
            out << "sig";
            return out;
        }

        friend void to_xml(const sig_& /*item*/, std::ostream& out)
        {
            out << "<sig/>\n";
        }

    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using sig = add_layer<sig_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class htan_
    {
    public:
        htan_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::tanh(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::tanh_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const htan_& , std::ostream& out)
        {
            serialize("htan_", out);
        }

        friend void deserialize(htan_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "htan_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const htan_& )
        {
            out << "htan";
            return out;
        }

        friend void to_xml(const htan_& /*item*/, std::ostream& out)
        {
            out << "<htan/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <typename SUBNET>
    using htan = add_layer<htan_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class softmax_
    {
    public:
        softmax_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::softmax(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::softmax_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const softmax_& , std::ostream& out)
        {
            serialize("softmax_", out);
        }

        friend void deserialize(softmax_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "softmax_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const softmax_& )
        {
            out << "softmax";
            return out;
        }

        friend void to_xml(const softmax_& /*item*/, std::ostream& out)
        {
            out << "<softmax/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <typename SUBNET>
    using softmax = add_layer<softmax_, SUBNET>;
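
    // Note that tt::softmax() normalizes over the k channels independently at each
    // spatial location (r,c) of each sample, so the channel values at every location
    // sum to 1.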

// ----------------------------------------------------------------------------------------
    namespace impl
    {
        template <template<typename> class TAG_TYPE, template<typename> class... TAG_TYPES>
        struct concat_helper_impl{

            constexpr static size_t tag_count() {return 1 + concat_helper_impl<TAG_TYPES...>::tag_count();}
            static void list_tags(std::ostream& out)
            {
                out << tag_id<TAG_TYPE>::id << (tag_count() > 1 ? "," : "");
                concat_helper_impl<TAG_TYPES...>::list_tags(out);
            }

            template<typename SUBNET>
            static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                concat_helper_impl<TAG_TYPES...>::resize_out(out, sub, sum_k + t.k());
            }
            template<typename SUBNET>
            static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                tt::copy_tensor(out, k_offset, t, 0, t.k());
                k_offset += t.k();
                concat_helper_impl<TAG_TYPES...>::concat(out, sub, k_offset);
            }
            template<typename SUBNET>
            static void split(const tensor& input, SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
                tt::copy_tensor(t, 0, input, k_offset, t.k());
                k_offset += t.k();
                concat_helper_impl<TAG_TYPES...>::split(input, sub, k_offset);
            }
        };
        template <template<typename> class TAG_TYPE>
        struct concat_helper_impl<TAG_TYPE>{
            constexpr static size_t tag_count() {return 1;}
            static void list_tags(std::ostream& out) 
            { 
                out << tag_id<TAG_TYPE>::id;
            }

            template<typename SUBNET>
            static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                out.set_size(t.num_samples(), t.k() + sum_k, t.nr(), t.nc());
            }
            template<typename SUBNET>
            static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                tt::copy_tensor(out, k_offset, t, 0, t.k());
            }
            template<typename SUBNET>
            static void split(const tensor& input, SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
                tt::copy_tensor(t, 0, input, k_offset, t.k());
            }
        };
    }
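
    // concat_helper_impl walks the TAG_TYPES parameter pack recursively: the general
    // template handles the first tag and recurses on the rest, while the single-tag
    // specialization terminates the recursion.  resize_out() accumulates the total
    // depth (k) across the tagged layers, concat() copies each tagged output into its
    // slice of the output tensor, and split() routes the matching slice of the
    // incoming gradient back to each tag's gradient input.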
    // concat layer
    template<
        template<typename> class... TAG_TYPES
        >
    class concat_
    {
        static void list_tags(std::ostream& out) { impl::concat_helper_impl<TAG_TYPES...>::list_tags(out); }

    public:
        constexpr static size_t tag_count() { return impl::concat_helper_impl<TAG_TYPES...>::tag_count(); }

        template <typename SUBNET>
        void setup (const SUBNET&)
        {
            // do nothing
        }
        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            // the total depth of the result is the sum of the depths of all the tagged layers
            impl::concat_helper_impl<TAG_TYPES...>::resize_out(output, sub, 0);

            // copy the output of each tagged layer into its own slice of the result
            impl::concat_helper_impl<TAG_TYPES...>::concat(output, sub, 0);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor&)
        {
            // The gradient is split into slices, one for each tagged layer.
            impl::concat_helper_impl<TAG_TYPES...>::split(gradient_input, sub, 0);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const concat_& item, std::ostream& out)
        {
            serialize("concat_", out);
            size_t count = tag_count();
            serialize(count, out);
        }

        friend void deserialize(concat_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "concat_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_.");
            size_t count_tags;
            deserialize(count_tags, in);
            if (count_tags != tag_count())
                throw serialization_error("Invalid count of tags "+ std::to_string(count_tags) +", expecting " +
                                          std::to_string(tag_count()) +
                                                  " found while deserializing dlib::concat_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const concat_& item)
        {
            out << "concat\t (";
            list_tags(out);
            out << ")";
            return out;
        }

        friend void to_xml(const concat_& item, std::ostream& out)
        {
            out << "<concat tags='";
            list_tags(out);
            out << "'/>\n";
        }

    private:
        resizable_tensor params; // unused
    };


    // concat layer definitions
    template <template<typename> class TAG1,
            template<typename> class TAG2,
            typename SUBNET>
    using concat2 = add_layer<concat_<TAG1, TAG2>, SUBNET>;

    template <template<typename> class TAG1,
            template<typename> class TAG2,
            template<typename> class TAG3,
            typename SUBNET>
    using concat3 = add_layer<concat_<TAG1, TAG2, TAG3>, SUBNET>;

    template <template<typename> class TAG1,
            template<typename> class TAG2,
            template<typename> class TAG3,
            template<typename> class TAG4,
            typename SUBNET>
    using concat4 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4>, SUBNET>;

    template <template<typename> class TAG1,
            template<typename> class TAG2,
            template<typename> class TAG3,
            template<typename> class TAG4,
            template<typename> class TAG5,
            typename SUBNET>
    using concat5 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4, TAG5>, SUBNET>;

    // The inception layer uses tags internally.  If the user also used tags there could
    // be conflicts, so these tags are reserved specifically for inception layers.
    template <typename SUBNET> using itag0  = add_tag_layer< 1000 + 0, SUBNET>;
    template <typename SUBNET> using itag1  = add_tag_layer< 1000 + 1, SUBNET>;
    template <typename SUBNET> using itag2  = add_tag_layer< 1000 + 2, SUBNET>;
    template <typename SUBNET> using itag3  = add_tag_layer< 1000 + 3, SUBNET>;
    template <typename SUBNET> using itag4  = add_tag_layer< 1000 + 4, SUBNET>;
    template <typename SUBNET> using itag5  = add_tag_layer< 1000 + 5, SUBNET>;
    // skip to inception input
    template <typename SUBNET> using iskip  = add_skip_layer< itag0, SUBNET>;

    // here are some templates to be used for creating inception layer groups
    template <template<typename>class B1,
            template<typename>class B2,
            typename SUBNET>
    using inception2 = concat2<itag1, itag2, itag1<B1<iskip< itag2<B2< itag0<SUBNET>>>>>>>;

    template <template<typename>class B1,
            template<typename>class B2,
            template<typename>class B3,
            typename SUBNET>
    using inception3 = concat3<itag1, itag2, itag3, itag1<B1<iskip< itag2<B2<iskip< itag3<B3<  itag0<SUBNET>>>>>>>>>>;

    template <template<typename>class B1,
            template<typename>class B2,
            template<typename>class B3,
            template<typename>class B4,
            typename SUBNET>
    using inception4 = concat4<itag1, itag2, itag3, itag4,
                itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip<  itag4<B4<  itag0<SUBNET>>>>>>>>>>>>>;

    template <template<typename>class B1,
            template<typename>class B2,
            template<typename>class B3,
            template<typename>class B4,
            template<typename>class B5,
            typename SUBNET>
    using inception5 = concat5<itag1, itag2, itag3, itag4, itag5,
                itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip<  itag4<B4<iskip<  itag5<B5<  itag0<SUBNET>>>>>>>>>>>>>>>>;
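
    // Illustrative sketch (hypothetical branch definitions, in the style of dlib's
    // dnn_inception_ex.cpp): each branch is a template taking a SUBNET, and
    // inceptionN concatenates the branch outputs depth-wise:
    //
    //     template <typename SUBNET> using block_a = relu<con<10,1,1,1,1,SUBNET>>;
    //     template <typename SUBNET> using block_b = relu<con<10,3,3,1,1,SUBNET>>;
    //     template <typename SUBNET> using incept  = inception2<block_a, block_b, SUBNET>;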
// ----------------------------------------------------------------------------------------

}

#endif // DLIB_DNn_LAYERS_H_