"vscode:/vscode.git/clone" did not exist on "d8a5d96b981bf6e1c5a61fde18acaeed0fb89f7c"
layers.h 46.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_LAYERS_H_
#define DLIB_DNn_LAYERS_H_

#include "layers_abstract.h"
#include "tensor.h"
#include "core.h"
#include <iostream>
#include <string>
#include "../rand.h"
#include "../string.h"
#include "tensor_tools.h"
#include "../vectorstream.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        long _num_filters,
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x
        >
    class con_
    {
    public:

        static_assert(_num_filters > 0, "The number of filters must be > 0");
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");

        con_(
        )  
        {}

        long num_filters() const { return _num_filters; }
        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }

        con_ (
            const con_& item
        ) : 
            params(item.params),
            filters(item.filters),
            biases(item.biases)
        {
            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
        }

        con_& operator= (
            const con_& item
        )
        {
            if (this == &item)
                return *this;

            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
            params = item.params;
            filters = item.filters;
            biases = item.biases;
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            long num_inputs = _nr*_nc*sub.get_output().k();
            long num_outputs = _num_filters;
            // allocate params for the filters and also for the filter bias values.
            params.set_size(num_inputs*_num_filters + _num_filters);
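            // Layout: the first num_inputs*_num_filters values are the filter
            // weights and the final _num_filters values are the biases; the
            // alias tensors created below index into params accordingly.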

            dlib::rand rnd(std::rand());
            randomize_parameters(params, num_inputs+num_outputs, rnd);

            filters = alias_tensor(_num_filters, sub.get_output().k(), _nr, _nc);
            biases = alias_tensor(1,_num_filters);

            // set the initial bias values to zero
            biases(params,filters.size()) = 0;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            conv(output,
                sub.get_output(),
                filters(params,0),
                _stride_y,
                _stride_x);

            tt::add(1,output,1,biases(params,filters.size()));
        } 

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            conv.get_gradient_for_data (gradient_input, filters(params,0), sub.get_gradient_input());
            auto filt = filters(params_grad,0);
            conv.get_gradient_for_filters (gradient_input, sub.get_output(), filt);
            auto b = biases(params_grad, filters.size());
            tt::assign_conv_bias_gradient(b, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const con_& item, std::ostream& out)
        {
            serialize("con_", out);
            serialize(item.params, out);
            serialize(_num_filters, out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.filters, out);
            serialize(item.biases, out);
        }

        friend void deserialize(con_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "con_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
            deserialize(item.params, in);


            long num_filters;
            long nr;
            long nc;
            int stride_y;
            int stride_x;
            deserialize(num_filters, in);
            deserialize(nr, in);
            deserialize(nc, in);
            deserialize(stride_y, in);
            deserialize(stride_x, in);
            deserialize(item.filters, in);
            deserialize(item.biases, in);

            if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
            if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
            if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
            if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
            if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
        }


        friend std::ostream& operator<<(std::ostream& out, const con_& item)
        {
            out << "con\t ("
                << "num_filters="<<_num_filters
                << ", nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ")";
            return out;
        }


    private:

        resizable_tensor params;
        alias_tensor filters, biases;

        tt::tensor_conv conv;

    };

    template <
        long num_filters,
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
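
    // A minimal usage sketch (illustrative only: input<matrix<float>> comes from
    // the dnn input layer header rather than this file, and the sizes are
    // arbitrary):
    //   using net_type = fc<10, relu<con<16,5,5,1,1, input<matrix<float>>>>>;
    // i.e. convolve with 16 5x5 filters at stride 1, apply a relu, then map the
    // result through a fully connected layer with 10 outputs.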

// ----------------------------------------------------------------------------------------

    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x
        >
    class max_pool_
    {
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
    public:


        max_pool_(
        ) {}

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }

        max_pool_ (
            const max_pool_& 
        )  
        {
            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
        }

        max_pool_& operator= (
            const max_pool_& item
        )
        {
            if (this == &item)
                return *this;

            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
            mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            mp(output, sub.get_output());
        } 

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const max_pool_& item, std::ostream& out)
        {
            serialize("max_pool_", out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
        }

        friend void deserialize(max_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "max_pool_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");

            item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);

            long nr;
            long nc;
            int stride_y;
            int stride_x;

            deserialize(nr, in);
            deserialize(nc, in);
            deserialize(stride_y, in);
            deserialize(stride_x, in);
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
        }

        friend std::ostream& operator<<(std::ostream& out, const max_pool_& item)
        {
            out << "max_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ")";
            return out;
        }


    private:


        tt::pooling mp;
        resizable_tensor params;
    };

    template <
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;

// ----------------------------------------------------------------------------------------

    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x
        >
    class avg_pool_
    {
    public:
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");

        avg_pool_(
        ) {}

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }

        avg_pool_ (
            const avg_pool_& 
        )  
        {
            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
        }

        avg_pool_& operator= (
            const avg_pool_& item
        )
        {
            if (this == &item)
                return *this;

            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
            ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            ap(output, sub.get_output());
        } 

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const avg_pool_& item, std::ostream& out)
        {
            serialize("avg_pool_", out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
        }

        friend void deserialize(avg_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "avg_pool_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");

            item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);

            long nr;
            long nc;
            int stride_y;
            int stride_x;

            deserialize(nr, in);
            deserialize(nc, in);
            deserialize(stride_y, in);
            deserialize(stride_x, in);
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
        }

        friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item)
        {
            out << "avg_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ")";
            return out;
        }
    private:

        tt::pooling ap;
        resizable_tensor params;
    };

    template <
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;

// ----------------------------------------------------------------------------------------

    enum layer_mode
    {
        CONV_MODE = 0,
        FC_MODE = 1
    };

    template <
        layer_mode mode
        >
    class bn_
    {
    public:
        bn_() : num_updates(0), running_stats_window_size(1000)
        {}

        explicit bn_(unsigned long window_size) : num_updates(0), running_stats_window_size(window_size)
        {}

        layer_mode get_mode() const { return mode; }
        unsigned long get_running_stats_window_size () const { return running_stats_window_size; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            if (mode == FC_MODE)
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;

            running_means.copy_size(gamma(params,0));
            running_variances.copy_size(gamma(params,0));
            running_means = 0;
            running_variances = 1;
            num_updates = 0;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            if (sub.get_output().num_samples() > 1)
            {
                const double decay = 1.0 - num_updates/(num_updates+1.0);
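                // Note that decay == 1/(num_updates+1), so the running stats are
                // plain averages until the window fills up and behave like an
                // exponential moving average with a window of about
                // running_stats_window_size updates after that.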
                if (num_updates < running_stats_window_size)
                    ++num_updates;
                if (mode == FC_MODE)
                    tt::batch_normalize(output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
                else 
                    tt::batch_normalize_conv(output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
            }
            else // we are running in testing mode so we just linearly scale the input tensor.
            {
                if (mode == FC_MODE)
                    tt::batch_normalize_inference(output, sub.get_output(), g, b, running_means, running_variances);
                else
                    tt::batch_normalize_conv_inference(output, sub.get_output(), g, b, running_means, running_variances);
            }
        } 

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            auto g = gamma(params,0);
            auto g_grad = gamma(params_grad, 0);
            auto b_grad = beta(params_grad, gamma.size());
            if (mode == FC_MODE)
                tt::batch_normalize_gradient(gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
            else
                tt::batch_normalize_conv_gradient(gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const bn_& item, std::ostream& out)
        {
            if (mode == CONV_MODE)
                serialize("bn_con", out);
            else // if FC_MODE
                serialize("bn_fc", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize(item.means, out);
            serialize(item.invstds, out);
            serialize(item.running_means, out);
            serialize(item.running_variances, out);
            serialize(item.num_updates, out);
            serialize(item.running_stats_window_size, out);
        }

        friend void deserialize(bn_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "bn_")
            {
                if (mode == CONV_MODE) 
                {
                    if (version != "bn_con")
                        throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
                }
                else // must be in FC_MODE
                {
                    if (version != "bn_fc")
                        throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
                }
            }

            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            deserialize(item.means, in);
            deserialize(item.invstds, in);
            deserialize(item.running_means, in);
            deserialize(item.running_variances, in);
            deserialize(item.num_updates, in);
            deserialize(item.running_stats_window_size, in);

            // if this is the older "bn_" version then check its saved mode value and make
            // sure it is the one we are really using.  
            if (version == "bn_")
            {
                int _mode;
                deserialize(_mode, in);
                if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");

                // We also need to flip the running_variances around since the previous
                // format saved the inverse standard deviations instead of variances:
                // invstd == 1/sqrt(var + eps), so var == 1/invstd^2 - eps.
                item.running_variances = 1.0f/squared(mat(item.running_variances)) - tt::BATCH_NORM_EPS;
            }
        }

        friend std::ostream& operator<<(std::ostream& out, const bn_& item)
        {
            if (mode == CONV_MODE)
                out << "bn_con";
            else
                out << "bn_fc";
            return out;
        }

    private:

        friend class affine_;

        resizable_tensor params;
        alias_tensor gamma, beta;
        resizable_tensor means, running_means;
        resizable_tensor invstds, running_variances;
        unsigned long num_updates;
        unsigned long running_stats_window_size;
    };

    template <typename SUBNET>
    using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
    template <typename SUBNET>
    using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
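
    // Given the gamma/beta shapes chosen in setup(): bn_con learns one
    // gamma/beta pair per channel (normalizing over the batch and the spatial
    // dimensions), while bn_fc learns one pair per individual activation
    // (normalizing over the batch dimension only).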

// ----------------------------------------------------------------------------------------

    enum fc_bias_mode
    {
        FC_HAS_BIAS = 0,
        FC_NO_BIAS = 1
    };

    struct num_fc_outputs
    {
        num_fc_outputs(unsigned long n) : num_outputs(n) {}
        unsigned long num_outputs;
    };

    template <
        unsigned long num_outputs_,
        fc_bias_mode bias_mode
        >
    class fc_
    {
        static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");

    public:
        fc_() : num_outputs(num_outputs_), num_inputs(0)
        {
        }

        fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0) {}

        unsigned long get_num_outputs (
        ) const { return num_outputs; }

        fc_bias_mode get_bias_mode (
        ) const { return bias_mode; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
            if (bias_mode == FC_HAS_BIAS)
                params.set_size(num_inputs+1, num_outputs);
            else
                params.set_size(num_inputs, num_outputs);
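
            // So with a bias the parameter matrix gets one extra row: rows
            // 0..num_inputs-1 hold the weights and the final row holds the
            // biases, which is the layout the weights/biases aliases below assume.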

            dlib::rand rnd(std::rand());
            randomize_parameters(params, num_inputs+num_outputs, rnd);

            weights = alias_tensor(num_inputs, num_outputs);

            if (bias_mode == FC_HAS_BIAS)
            {
                biases = alias_tensor(1,num_outputs);
                // set the initial bias values to zero
                biases(params,weights.size()) = 0;
            }
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            output.set_size(sub.get_output().num_samples(), num_outputs);

            auto w = weights(params, 0);
            tt::gemm(0,output, 1,sub.get_output(),false, w,false);
            if (bias_mode == FC_HAS_BIAS)
            {
                auto b = biases(params, weights.size());
                tt::add(1,output,1,b);
            }
        } 

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            // compute the gradient of the weight parameters.  
            auto pw = weights(params_grad, 0);
            tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);

            if (bias_mode == FC_HAS_BIAS)
            {
                // compute the gradient of the bias parameters.  
                auto pb = biases(params_grad, weights.size());
                tt::assign_bias_gradient(pb, gradient_input);
            }

            // compute the gradient for the data
            auto w = weights(params, 0);
            tt::gemm(1,sub.get_gradient_input(), 1,gradient_input,false, w,true);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const fc_& item, std::ostream& out)
        {
            serialize("fc_", out);
            serialize(item.num_outputs, out);
            serialize(item.num_inputs, out);
            serialize(item.params, out);
            serialize(item.weights, out);
            serialize(item.biases, out);
            serialize((int)bias_mode, out);
        }

        friend void deserialize(fc_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "fc_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");

            deserialize(item.num_outputs, in);
            deserialize(item.num_inputs, in);
            deserialize(item.params, in);
            deserialize(item.weights, in);
            deserialize(item.biases, in);
            int bmode = 0;
            deserialize(bmode, in);
            if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
        }

        friend std::ostream& operator<<(std::ostream& out, const fc_& item)
        {
            if (bias_mode == FC_HAS_BIAS)
            {
                out << "fc\t ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
            }
            else
            {
                out << "fc_no_bias ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
            }
            return out;
        }

    private:

        unsigned long num_outputs;
        unsigned long num_inputs;
        resizable_tensor params;
        alias_tensor weights, biases;
    };

    template <
        unsigned long num_outputs,
        typename SUBNET
        >
    using fc = add_layer<fc_<num_outputs,FC_HAS_BIAS>, SUBNET>;

    template <
        unsigned long num_outputs,
        typename SUBNET
        >
    using fc_no_bias = add_layer<fc_<num_outputs,FC_NO_BIAS>, SUBNET>;

// ----------------------------------------------------------------------------------------

    class dropout_
    {
    public:
        explicit dropout_(
            float drop_rate_ = 0.5
        ) :
            drop_rate(drop_rate_)
        {
            DLIB_CASSERT(0 <= drop_rate && drop_rate <= 1,"");
        }

        // We have to add a copy constructor and assignment operator because the rnd object
        // is non-copyable.
        dropout_(
            const dropout_& item
        ) : drop_rate(item.drop_rate), mask(item.mask)
        {}

        dropout_& operator= (
            const dropout_& item
        )
        {
            if (this == &item)
                return *this;

            drop_rate = item.drop_rate;
            mask = item.mask;
            return *this;
        }

        float get_drop_rate (
        ) const { return drop_rate; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            // create a random mask and use it to filter the data
            mask.copy_size(input);
            rnd.fill_uniform(mask);
            tt::threshold(mask, drop_rate);
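            // threshold() maps each mask element to 1 if it is > drop_rate and to 0
            // otherwise, so on average a drop_rate fraction of the activations are
            // zeroed.  Note that nothing is rescaled here; see multiply_ below for
            // the matching test time scaling.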
            tt::multiply(output, input, mask);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            tt::multiply(data_grad, mask, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const dropout_& item, std::ostream& out)
        {
            serialize("dropout_", out);
            serialize(item.drop_rate, out);
            serialize(item.mask, out);
        }

        friend void deserialize(dropout_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "dropout_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_.");
            deserialize(item.drop_rate, in);
            deserialize(item.mask, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const dropout_& item)
        {
            out << "dropout\t ("
                << "drop_rate="<<item.drop_rate
                << ")";
            return out;
        }

    private:
        float drop_rate;
        resizable_tensor mask;

        tt::tensor_rand rnd;
        resizable_tensor params; // unused
    };


    template <typename SUBNET>
    using dropout = add_layer<dropout_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class multiply_
    {
    public:
        explicit multiply_(
            float val_ = 0.5
        ) :
            val(val_)
        {
        }

        multiply_ (
            const dropout_& item
        ) : val(1-item.get_drop_rate()) {}
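
        // Scaling by 1-drop_rate is, in expectation, what a dropout layer does to
        // each activation during training, so constructing a multiply_ from a
        // dropout_ gives the usual test time replacement for that layer.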

        float get_multiply_value (
        ) const { return val; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::affine_transform(output, input, val, 0);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            tt::affine_transform(data_grad, gradient_input, val, 0);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const multiply_& item, std::ostream& out)
        {
            serialize("multiply_", out);
            serialize(item.val, out);
        }

        friend void deserialize(multiply_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "dropout_")
            {
                // Since we can build a multiply_ from a dropout_ we check if that's what
                // is in the stream and if so then just convert it right here.
                unserialize sin(version, in);
                dropout_ temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }

            if (version != "multiply_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_.");
            deserialize(item.val, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const multiply_& item)
        {
            out << "multiply ("
                << "val="<<item.val
                << ")";
            return out;
        }

    private:
        float val;
        resizable_tensor params; // unused
    };

    template <typename SUBNET>
    using multiply = add_layer<multiply_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class affine_
    {
    public:
        affine_(
        ) : mode(FC_MODE)
        {
        }

        affine_(
            layer_mode mode_
        ) : mode(mode_)
        {
        }

        template <
            layer_mode bnmode
            >
        affine_(
            const bn_<bnmode>& item
        )
        {
            gamma = item.gamma;
            beta = item.beta;
            mode = bnmode;
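
            // At test time a bn_ layer computes
            //   output == gamma*(input-running_mean)/sqrt(running_variance+eps) + beta
            // which is just an affine transform.  So below we fold the running
            // statistics into new gamma and beta values that compute the same thing.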

            params.copy_size(item.params);

            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            
            resizable_tensor temp(item.params);
            auto sg = gamma(temp,0);
            auto sb = beta(temp,gamma.size());

            g = pointwise_multiply(mat(sg), 1.0f/sqrt(mat(item.running_variances)+tt::BATCH_NORM_EPS));
            b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means));
        }

        layer_mode get_mode() const { return mode; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            if (mode == FC_MODE)
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            if (mode == FC_MODE)
                tt::affine_transform(output, input, g, b);
            else
                tt::affine_transform_conv(output, input, g, b);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());

            // We are computing the gradient of dot(gradient_input, computed_output*g + b)
            if (mode == FC_MODE)
            {
                tt::multiply(data_grad, gradient_input, g);
            }
            else
            {
                tt::multiply_conv(data_grad, gradient_input, g);
            }
        }

        const tensor& get_layer_params() const { return empty_params; }
        tensor& get_layer_params() { return empty_params; }

        friend void serialize(const affine_& item, std::ostream& out)
        {
            serialize("affine_", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize((int)item.mode, out);
        }

        friend void deserialize(affine_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "bn_con")
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
                bn_<CONV_MODE> temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }
            else if (version == "bn_fc")
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
                bn_<FC_MODE> temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }

            if (version != "affine_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");
            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            int mode;
            deserialize(mode, in);
            item.mode = (layer_mode)mode;
        }

        friend std::ostream& operator<<(std::ostream& out, const affine_& )
        {
            out << "affine";
            return out;
        }

    private:
        resizable_tensor params, empty_params; 
        alias_tensor gamma, beta;
        layer_mode mode;
    };

    template <typename SUBNET>
    using affine = add_layer<affine_, SUBNET>;

// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class add_prev_
    {
    public:
        add_prev_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            output.copy_size(sub.get_output());
            tt::add(output, sub.get_output(), layer<tag>(sub).get_output());
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            // The gradient just flows backwards to the two layers that forward() added
            // together.
            tt::add(sub.get_gradient_input(), sub.get_gradient_input(), gradient_input);
            tt::add(layer<tag>(sub).get_gradient_input(), layer<tag>(sub).get_gradient_input(), gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const add_prev_& , std::ostream& out)
        {
            serialize("add_prev_", out);
        }

        friend void deserialize(add_prev_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "add_prev_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const add_prev_& item)
        {
            out << "add_prev";
            return out;
        }


    private:
        resizable_tensor params;
    };

    template <
        template<typename> class tag,
        typename SUBNET
        >
    using add_prev = add_layer<add_prev_<tag>, SUBNET>;

    template <typename SUBNET> using add_prev1  = add_prev<tag1, SUBNET>;
    template <typename SUBNET> using add_prev2  = add_prev<tag2, SUBNET>;
    template <typename SUBNET> using add_prev3  = add_prev<tag3, SUBNET>;
    template <typename SUBNET> using add_prev4  = add_prev<tag4, SUBNET>;
    template <typename SUBNET> using add_prev5  = add_prev<tag5, SUBNET>;
    template <typename SUBNET> using add_prev6  = add_prev<tag6, SUBNET>;
    template <typename SUBNET> using add_prev7  = add_prev<tag7, SUBNET>;
    template <typename SUBNET> using add_prev8  = add_prev<tag8, SUBNET>;
    template <typename SUBNET> using add_prev9  = add_prev<tag9, SUBNET>;
    template <typename SUBNET> using add_prev10 = add_prev<tag10, SUBNET>;

    using add_prev1_  = add_prev_<tag1>;
    using add_prev2_  = add_prev_<tag2>;
    using add_prev3_  = add_prev_<tag3>;
    using add_prev4_  = add_prev_<tag4>;
    using add_prev5_  = add_prev_<tag5>;
    using add_prev6_  = add_prev_<tag6>;
    using add_prev7_  = add_prev_<tag7>;
    using add_prev8_  = add_prev_<tag8>;
    using add_prev9_  = add_prev_<tag9>;
    using add_prev10_ = add_prev_<tag10>;
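
    // Typical use: mark a point in the network with tagN<...> and then place
    // add_prevN<...> later in the network so its input is summed, elementwise,
    // with the output of the tagged layer.  This is the basic building block of
    // residual style architectures.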

// ----------------------------------------------------------------------------------------

    class relu_
    {
    public:
        relu_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::relu(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::relu_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const relu_& , std::ostream& out)
        {
            serialize("relu_", out);
        }

        friend void deserialize(relu_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "relu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const relu_& )
        {
            out << "relu";
            return out;
        }


    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using relu = add_layer<relu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class prelu_
    {
    public:
        explicit prelu_(
            float initial_param_value_ = 0.25
        ) : initial_param_value(initial_param_value_)
        {
        }

        float get_initial_param_value (
        ) const { return initial_param_value; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
            params.set_size(1);
            params = initial_param_value;
        }

        template <typename SUBNET>
        void forward(
            const SUBNET& sub, 
            resizable_tensor& data_output
        )
        {
            data_output.copy_size(sub.get_output());
            tt::prelu(data_output, sub.get_output(), params);
        }

        template <typename SUBNET>
        void backward(
            const tensor& gradient_input, 
            SUBNET& sub, 
            tensor& params_grad
        )
        {
            tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(), 
                gradient_input, params, params_grad);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const prelu_& item, std::ostream& out)
        {
            serialize("prelu_", out);
            serialize(item.params, out);
            serialize(item.initial_param_value, out);
        }

        friend void deserialize(prelu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "prelu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_.");
            deserialize(item.params, in);
            deserialize(item.initial_param_value, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const prelu_& item)
        {
            out << "prelu\t ("
                << "initial_param_value="<<item.initial_param_value
                << ")";
            return out;
        }

    private:
        resizable_tensor params;
        float initial_param_value;
    };

    template <typename SUBNET>
    using prelu = add_layer<prelu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class sig_
    {
    public:
        sig_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::sigmoid(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::sigmoid_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const sig_& , std::ostream& out)
        {
            serialize("sig_", out);
        }

        friend void deserialize(sig_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "sig_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const sig_& )
        {
            out << "sig";
            return out;
        }


    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using sig = add_layer<sig_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class htan_
    {
    public:
        htan_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::tanh(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::tanh_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const htan_& , std::ostream& out)
        {
            serialize("htan_", out);
        }

        friend void deserialize(htan_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "htan_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const htan_& )
        {
            out << "htan";
            return out;
        }


    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using htan = add_layer<htan_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class softmax_
    {
    public:
        softmax_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::softmax(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::softmax_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const softmax_& , std::ostream& out)
        {
            serialize("softmax_", out);
        }

        friend void deserialize(softmax_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "softmax_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const softmax_& )
        {
            out << "softmax";
            return out;
        }

    private:
        resizable_tensor params;
    };

    template <typename SUBNET>
    using softmax = add_layer<softmax_, SUBNET>;

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_DNn_LAYERS_H_