"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "31b49c0b5112c8919e447aa247ae4d49c31b229e"
layers.h 46.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_LAYERS_H_
#define DLIB_DNn_LAYERS_H_

#include "layers_abstract.h"
#include "tensor.h"
#include "core.h"
#include <iostream>
#include <string>
Davis King's avatar
Davis King committed
11
12
#include "../rand.h"
#include "../string.h"
13
#include "tensor_tools.h"
14
#include "../vectorstream.h"
15
16
17
18
19
20
21


namespace dlib
{

// ----------------------------------------------------------------------------------------

22
23
24
25
26
27
28
    template <
        long _num_filters,
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x
        >
29
30
31
    class con_
    {
    public:
32

33
34
35
36
37
        static_assert(_num_filters > 0, "The number of filters must be > 0");
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
38

Davis King's avatar
Davis King committed
39
        con_(
40
        )  
41
42
        {}

43
44
45
46
47
48
        long num_filters() const { return _num_filters; }
        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }

Davis King's avatar
Davis King committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
        con_ (
            const con_& item
        ) : 
            params(item.params),
            filters(item.filters),
            biases(item.biases)
        {
            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
        }

        con_& operator= (
            const con_& item
        )
        {
            if (this == &item)
                return *this;

            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
            params = item.params;
            filters = item.filters;
            biases = item.biases;
            return *this;
        }

Davis King's avatar
Davis King committed
75
76
        template <typename SUBNET>
        void setup (const SUBNET& sub)
77
        {
78
79
            long num_inputs = _nr*_nc*sub.get_output().k();
            long num_outputs = _num_filters;
Davis King's avatar
Davis King committed
80
            // allocate params for the filters and also for the filter bias values.
81
            params.set_size(num_inputs*_num_filters + _num_filters);
Davis King's avatar
Davis King committed
82

83
            dlib::rand rnd(std::rand());
Davis King's avatar
Davis King committed
84
85
            randomize_parameters(params, num_inputs+num_outputs, rnd);

86
87
            filters = alias_tensor(_num_filters, sub.get_output().k(), _nr, _nc);
            biases = alias_tensor(1,_num_filters);
Davis King's avatar
Davis King committed
88
89
90

            // set the initial bias values to zero
            biases(params,filters.size()) = 0;
91
92
        }

Davis King's avatar
Davis King committed
93
94
        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
95
        {
Davis King's avatar
Davis King committed
96
97
98
            conv(output,
                sub.get_output(),
                filters(params,0),
99
                _stride_y,
100
101
102
103
                _stride_x,
                _nr/2,
                _nc/2
                );
Davis King's avatar
Davis King committed
104
105

            tt::add(1,output,1,biases(params,filters.size()));
106
107
        } 

Davis King's avatar
Davis King committed
108
        template <typename SUBNET>
109
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
110
        {
Davis King's avatar
Davis King committed
111
112
113
114
            conv.get_gradient_for_data (gradient_input, filters(params,0), sub.get_gradient_input());
            auto filt = filters(params_grad,0);
            conv.get_gradient_for_filters (gradient_input, sub.get_output(), filt);
            auto b = biases(params_grad, filters.size());
115
            tt::assign_conv_bias_gradient(b, gradient_input);
116
117
118
119
120
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

Davis King's avatar
Davis King committed
121
122
123
124
        friend void serialize(const con_& item, std::ostream& out)
        {
            serialize("con_", out);
            serialize(item.params, out);
125
126
127
128
129
            serialize(_num_filters, out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
Davis King's avatar
Davis King committed
130
131
132
133
134
135
136
137
138
            serialize(item.filters, out);
            serialize(item.biases, out);
        }

        friend void deserialize(con_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "con_")
139
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
Davis King's avatar
Davis King committed
140
            deserialize(item.params, in);
141
142
143
144
145
146
147
148
149
150
151
152


            long num_filters;
            long nr;
            long nc;
            int stride_y;
            int stride_x;
            deserialize(num_filters, in);
            deserialize(nr, in);
            deserialize(nc, in);
            deserialize(stride_y, in);
            deserialize(stride_x, in);
Davis King's avatar
Davis King committed
153
154
            deserialize(item.filters, in);
            deserialize(item.biases, in);
155
156
157
158
159
160

            if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
            if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
            if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
            if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
            if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
Davis King's avatar
Davis King committed
161
162
        }

163
164
165
166
167
168
169
170
171
172
173
174
175
176

        friend std::ostream& operator<<(std::ostream& out, const con_& item)
        {
            out << "con\t ("
                << "num_filters="<<_num_filters
                << ", nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ")";
            return out;
        }


177
178
179
    private:

        resizable_tensor params;
Davis King's avatar
Davis King committed
180
181
182
183
        alias_tensor filters, biases;

        tt::tensor_conv conv;

184
185
    };

186
187
188
189
190
191
192
193
194
    // Alias that wraps a con_ layer with the given configuration, together with
    // SUBNET, in add_layer.
    template <long num_filters, long nr, long nc, int stride_y, int stride_x, typename SUBNET>
    using con = add_layer<
        con_<num_filters, nr, nc, stride_y, stride_x>,
        SUBNET
        >;
195

Davis King's avatar
Davis King committed
196
197
// ----------------------------------------------------------------------------------------

198
199
200
201
202
203
    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x
        >
Davis King's avatar
Davis King committed
204
205
    class max_pool_
    {
206
207
208
209
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
Davis King's avatar
Davis King committed
210
211
212
213
    public:


        max_pool_(
214
        ) {}
Davis King's avatar
Davis King committed
215
216
217
218
219
220
221

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }

        max_pool_ (
222
223
            const max_pool_& 
        )  
Davis King's avatar
Davis King committed
224
225
226
        {
            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
227
            mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
Davis King's avatar
Davis King committed
228
229
230
231
232
233
234
235
236
237
238
        }

        max_pool_& operator= (
            const max_pool_& item
        )
        {
            if (this == &item)
                return *this;

            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
239
            mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
Davis King's avatar
Davis King committed
240
241
242
243
244
245
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
246
            mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
Davis King's avatar
Davis King committed
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            mp(output, sub.get_output());
        } 

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const max_pool_& item, std::ostream& out)
        {
            serialize("max_pool_", out);
267
268
269
270
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
Davis King's avatar
Davis King committed
271
272
273
274
275
276
277
        }

        friend void deserialize(max_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "max_pool_")
278
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
Davis King's avatar
Davis King committed
279

280
            item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
281
282
283
284
285
286
287
288
289
290
291
292
293
294

            long nr;
            long nc;
            int stride_y;
            int stride_x;

            deserialize(nr, in);
            deserialize(nc, in);
            deserialize(stride_y, in);
            deserialize(stride_x, in);
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
Davis King's avatar
Davis King committed
295
296
        }

297
298
299
300
301
302
303
304
305
306
307
308
        friend std::ostream& operator<<(std::ostream& out, const max_pool_& item)
        {
            out << "max_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ")";
            return out;
        }


Davis King's avatar
Davis King committed
309
310
311
    private:


312
        tt::pooling mp;
Davis King's avatar
Davis King committed
313
314
315
        resizable_tensor params;
    };

316
317
318
319
320
321
322
323
    // Alias that wraps a max_pool_ layer with the given window size and strides,
    // together with SUBNET, in add_layer.
    template <long nr, long nc, int stride_y, int stride_x, typename SUBNET>
    using max_pool = add_layer<
        max_pool_<nr, nc, stride_y, stride_x>,
        SUBNET
        >;
Davis King's avatar
Davis King committed
324

325
326
// ----------------------------------------------------------------------------------------

327
328
329
330
331
332
    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x
        >
333
334
335
    class avg_pool_
    {
    public:
336
337
338
339
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
340
341

        avg_pool_(
342
        ) {}
343
344
345
346
347
348
349

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }

        avg_pool_ (
350
351
            const avg_pool_& 
        )  
352
353
354
        {
            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
355
            ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
356
357
358
359
360
361
362
363
364
365
366
        }

        avg_pool_& operator= (
            const avg_pool_& item
        )
        {
            if (this == &item)
                return *this;

            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
367
            ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
368
369
370
371
372
373
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
374
            ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            ap(output, sub.get_output());
        } 

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const avg_pool_& item, std::ostream& out)
        {
            serialize("avg_pool_", out);
395
396
397
398
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
399
400
401
402
403
404
405
        }

        friend void deserialize(avg_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "avg_pool_")
406
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
407

408
            item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
409
410
411
412
413
414
415
416
417
418
419
420
421
422

            long nr;
            long nc;
            int stride_y;
            int stride_x;

            deserialize(nr, in);
            deserialize(nc, in);
            deserialize(stride_y, in);
            deserialize(stride_x, in);
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
423
424
        }

425
426
427
428
429
430
431
432
433
434
        friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item)
        {
            out << "avg_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ")";
            return out;
        }
435
436
437
438
439
440
    private:

        tt::pooling ap;
        resizable_tensor params;
    };

441
442
443
444
445
446
447
448
    // Alias that wraps an avg_pool_ layer with the given window size and strides,
    // together with SUBNET, in add_layer.
    template <long nr, long nc, int stride_y, int stride_x, typename SUBNET>
    using avg_pool = add_layer<
        avg_pool_<nr, nc, stride_y, stride_x>,
        SUBNET
        >;
449

450
451
// ----------------------------------------------------------------------------------------

452
    // Selects which batch normalization variant bn_ uses.  The numeric values are
    // part of the legacy "bn_" serialization format (the mode is stored as an int),
    // so they must not change.
    enum layer_mode
    {
        CONV_MODE = 0,
        FC_MODE = 1
    };

458
459
460
    template <
        layer_mode mode
        >
461
462
463
    class bn_
    {
    public:
464
        bn_() : num_updates(0), running_stats_window_size(1000)
465
466
        {}

467
        explicit bn_(unsigned long window_size) : num_updates(0), running_stats_window_size(window_size)
468
469
470
471
        {}

        layer_mode get_mode() const { return mode; }
        unsigned long get_running_stats_window_size () const { return running_stats_window_size; }
472

473
474
475
        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
476
            if (mode == FC_MODE)
477
478
479
480
481
482
483
484
485
486
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
Davis King's avatar
Davis King committed
487
488
489
490
491
492
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;
493

494
            running_means.copy_size(gamma(params,0));
495
            running_variances.copy_size(gamma(params,0));
496
            running_means = 0;
497
            running_variances = 1;
498
            num_updates = 0;
499
500
501
502
503
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
Davis King's avatar
Davis King committed
504
505
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
506
507
            if (sub.get_output().num_samples() > 1)
            {
508
                const double decay = 1.0 - num_updates/(num_updates+1.0);
509
510
                if (num_updates <running_stats_window_size)
                    ++num_updates;
511
                if (mode == FC_MODE)
512
                    tt::batch_normalize(output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
513
                else 
514
                    tt::batch_normalize_conv(output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
515
516
517
            }
            else // we are running in testing mode so we just linearly scale the input tensor.
            {
518
                if (mode == FC_MODE)
519
                    tt::batch_normalize_inference(output, sub.get_output(), g, b, running_means, running_variances);
520
                else
521
                    tt::batch_normalize_conv_inference(output, sub.get_output(), g, b, running_means, running_variances);
522
            }
523
524
525
526
527
        } 

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
Davis King's avatar
Davis King committed
528
529
530
            auto g = gamma(params,0);
            auto g_grad = gamma(params_grad, 0);
            auto b_grad = beta(params_grad, gamma.size());
531
            if (mode == FC_MODE)
532
533
534
                tt::batch_normalize_gradient(gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
            else
                tt::batch_normalize_conv_gradient(gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
535
536
537
538
539
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

Davis King's avatar
Davis King committed
540
541
        friend void serialize(const bn_& item, std::ostream& out)
        {
542
543
544
545
            if (mode == CONV_MODE)
                serialize("bn_con", out);
            else // if FC_MODE
                serialize("bn_fc", out);
Davis King's avatar
Davis King committed
546
547
548
549
550
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize(item.means, out);
            serialize(item.invstds, out);
551
            serialize(item.running_means, out);
552
            serialize(item.running_variances, out);
553
554
            serialize(item.num_updates, out);
            serialize(item.running_stats_window_size, out);
Davis King's avatar
Davis King committed
555
556
557
558
559
560
561
        }

        friend void deserialize(bn_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "bn_")
562
563
564
565
566
567
568
569
570
571
572
573
574
            {
                if (mode == CONV_MODE) 
                {
                    if (version != "bn_con")
                        throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
                }
                else // must be in FC_MODE
                {
                    if (version != "bn_fc")
                        throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
                }
            }

Davis King's avatar
Davis King committed
575
576
577
578
579
            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            deserialize(item.means, in);
            deserialize(item.invstds, in);
580
            deserialize(item.running_means, in);
581
            deserialize(item.running_variances, in);
582
583
            deserialize(item.num_updates, in);
            deserialize(item.running_stats_window_size, in);
584
585
586
587
588
589
590
591
592

            // if this is the older "bn_" version then check its saved mode value and make
            // sure it is the one we are really using.  
            if (version == "bn_")
            {
                int _mode;
                deserialize(_mode, in);
                if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");

593
                // We also need to flip the running_variances around since the previous
594
                // format saved the inverse standard deviations instead of variances.
595
                item.running_variances = 1.0f/squared(mat(item.running_variances)) - tt::BATCH_NORM_EPS;
596
            }
Davis King's avatar
Davis King committed
597
598
        }

599
600
601
602
603
604
605
606
607
        friend std::ostream& operator<<(std::ostream& out, const bn_& item)
        {
            if (mode == CONV_MODE)
                out << "bn_con";
            else
                out << "bn_fc";
            return out;
        }

608
609
    private:

610
611
        friend class affine_;

612
        resizable_tensor params;
Davis King's avatar
Davis King committed
613
        alias_tensor gamma, beta;
614
        resizable_tensor means, running_means;
615
        resizable_tensor invstds, running_variances;
616
617
        unsigned long num_updates;
        unsigned long running_stats_window_size;
618
619
620
    };

    template <typename SUBNET>
621
622
623
    using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
    // Alias that wraps a bn_ layer in FC_MODE, together with SUBNET, in add_layer.
    template <typename SUBNET>
    using bn_fc = add_layer<
        bn_<FC_MODE>,
        SUBNET
        >;
624

625
626
// ----------------------------------------------------------------------------------------

627
628
    // Selects whether an fc_ layer has a bias term.  fc_ serializes this value as an
    // int, so the numeric values must not change.
    enum fc_bias_mode
    {
        FC_HAS_BIAS = 0,
        FC_NO_BIAS = 1
    };

633
634
635
636
637
638
639
640
641
642
    // Small wrapper around an output count; fc_ has a constructor that accepts it.
    // The constructor is not explicit, so an unsigned long converts implicitly.
    struct num_fc_outputs
    {
        num_fc_outputs(unsigned long count)
            : num_outputs(count)
        {
        }

        unsigned long num_outputs;
    };

    template <
        unsigned long num_outputs_,
        fc_bias_mode bias_mode
        >
643
644
    class fc_
    {
645
646
        static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");

647
    public:
648
        fc_() : num_outputs(num_outputs_), num_inputs(0)
649
650
651
        {
        }

652
        fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0) {}
653
654
655
656

        unsigned long get_num_outputs (
        ) const { return num_outputs; }

657
658
659
        fc_bias_mode get_bias_mode (
        ) const { return bias_mode; }

Davis King's avatar
Davis King committed
660
661
        template <typename SUBNET>
        void setup (const SUBNET& sub)
662
663
        {
            num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
664
665
666
667
            if (bias_mode == FC_HAS_BIAS)
                params.set_size(num_inputs+1, num_outputs);
            else
                params.set_size(num_inputs, num_outputs);
668

669
            dlib::rand rnd(std::rand());
670
            randomize_parameters(params, num_inputs+num_outputs, rnd);
671
672
673
674
675
676
677
678
679

            weights = alias_tensor(num_inputs, num_outputs);

            if (bias_mode == FC_HAS_BIAS)
            {
                biases = alias_tensor(1,num_outputs);
                // set the initial bias values to zero
                biases(params,weights.size()) = 0;
            }
680
681
        }

Davis King's avatar
Davis King committed
682
683
        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
684
        {
685
            output.set_size(sub.get_output().num_samples(), num_outputs);
686

687
688
689
690
691
692
693
            auto w = weights(params, 0);
            tt::gemm(0,output, 1,sub.get_output(),false, w,false);
            if (bias_mode == FC_HAS_BIAS)
            {
                auto b = biases(params, weights.size());
                tt::add(1,output,1,b);
            }
694
695
        } 

Davis King's avatar
Davis King committed
696
        template <typename SUBNET>
697
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
698
        {
699
700
701
702
703
704
705
706
            // compute the gradient of the weight parameters.  
            auto pw = weights(params_grad, 0);
            tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);

            if (bias_mode == FC_HAS_BIAS)
            {
                // compute the gradient of the bias parameters.  
                auto pb = biases(params_grad, weights.size());
707
                tt::assign_bias_gradient(pb, gradient_input);
708
            }
709
710

            // compute the gradient for the data
711
712
            auto w = weights(params, 0);
            tt::gemm(1,sub.get_gradient_input(), 1,gradient_input,false, w,true);
713
714
715
716
717
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

718
719
720
721
722
723
        friend void serialize(const fc_& item, std::ostream& out)
        {
            serialize("fc_", out);
            serialize(item.num_outputs, out);
            serialize(item.num_inputs, out);
            serialize(item.params, out);
724
725
            serialize(item.weights, out);
            serialize(item.biases, out);
726
            serialize((int)bias_mode, out);
727
728
729
730
731
732
733
        }

        friend void deserialize(fc_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "fc_")
734
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
735

736
737
738
            deserialize(item.num_outputs, in);
            deserialize(item.num_inputs, in);
            deserialize(item.params, in);
739
740
741
742
            deserialize(item.weights, in);
            deserialize(item.biases, in);
            int bmode = 0;
            deserialize(bmode, in);
743
            if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
744
745
        }

746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
        friend std::ostream& operator<<(std::ostream& out, const fc_& item)
        {
            if (bias_mode == FC_HAS_BIAS)
            {
                out << "fc\t ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
            }
            else
            {
                out << "fc_no_bias ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
            }
            return out;
        }

763
764
765
766
767
    private:

        unsigned long num_outputs;
        unsigned long num_inputs;
        resizable_tensor params;
768
        alias_tensor weights, biases;
769
770
    };

771
772
773
774
    template <
        unsigned long num_outputs,
        typename SUBNET
        >
775
776
777
778
779
780
781
    using fc = add_layer<fc_<num_outputs,FC_HAS_BIAS>, SUBNET>;

    // Alias that wraps an fc_ layer without a bias term (FC_NO_BIAS), together with
    // SUBNET, in add_layer.
    template <unsigned long num_outputs, typename SUBNET>
    using fc_no_bias = add_layer<
        fc_<num_outputs, FC_NO_BIAS>,
        SUBNET
        >;
782

Davis King's avatar
Davis King committed
783
784
785
786
787
788
789
790
791
792
// ----------------------------------------------------------------------------------------

    class dropout_
    {
    public:
        explicit dropout_(
            float drop_rate_ = 0.5
        ) :
            drop_rate(drop_rate_)
        {
793
            DLIB_CASSERT(0 <= drop_rate && drop_rate <= 1,"");
Davis King's avatar
Davis King committed
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
        }

        // We have to add a copy constructor and assignment operator because the rnd object
        // is non-copyable.
        dropout_(
            const dropout_& item
        ) : drop_rate(item.drop_rate), mask(item.mask)
        {}

        dropout_& operator= (
            const dropout_& item
        )
        {
            if (this == &item)
                return *this;

            drop_rate = item.drop_rate;
            mask = item.mask;
            return *this;
        }

        float get_drop_rate (
        ) const { return drop_rate; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            // create a random mask and use it to filter the data
            mask.copy_size(input);
            rnd.fill_uniform(mask);
            tt::threshold(mask, drop_rate);
            tt::multiply(output, input, mask);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            tt::multiply(data_grad, mask, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const dropout_& item, std::ostream& out)
        {
            serialize("dropout_", out);
            serialize(item.drop_rate, out);
            serialize(item.mask, out);
        }

        friend void deserialize(dropout_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "dropout_")
856
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_.");
Davis King's avatar
Davis King committed
857
858
859
860
            deserialize(item.drop_rate, in);
            deserialize(item.mask, in);
        }

861
862
863
864
865
866
867
868
        friend std::ostream& operator<<(std::ostream& out, const dropout_& item)
        {
            out << "dropout\t ("
                << "drop_rate="<<item.drop_rate
                << ")";
            return out;
        }

Davis King's avatar
Davis King committed
869
870
871
872
873
874
875
876
877
878
879
880
    private:
        float drop_rate;
        resizable_tensor mask;

        tt::tensor_rand rnd;
        resizable_tensor params; // unused
    };


    // dropout_ stacked on top of SUBNET.
    template <typename SUBNET> using dropout = add_layer<dropout_, SUBNET>;

881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
// ----------------------------------------------------------------------------------------

    class multiply_
    {
    public:
        explicit multiply_(
            float val_ = 0.5
        ) :
            val(val_)
        {
        }

        multiply_ (
            const dropout_& item
        ) : val(1-item.get_drop_rate()) {}

        float get_multiply_value (
        ) const { return val; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::affine_transform(output, input, val, 0);
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& /*params_grad*/
        )
        {
            tt::affine_transform(data_grad, gradient_input, val, 0);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const multiply_& item, std::ostream& out)
        {
            serialize("multiply_", out);
            serialize(item.val, out);
        }

        friend void deserialize(multiply_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "dropout_")
            {
                // Since we can build a multiply_ from a dropout_ we check if that's what
                // is in the stream and if so then just convert it right here.
                unserialize sin(version, in);
                dropout_ temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }

            if (version != "multiply_")
944
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_.");
945
946
947
            deserialize(item.val, in);
        }

948
949
950
951
952
953
954
955
        friend std::ostream& operator<<(std::ostream& out, const multiply_& item)
        {
            out << "multiply ("
                << "val="<<item.val
                << ")";
            return out;
        }

956
957
958
959
960
961
962
963
    private:
        float val;
        resizable_tensor params; // unused
    };

    // multiply_ stacked on top of SUBNET.
    template <typename SUBNET> using multiply = add_layer<multiply_, SUBNET>;

Davis King's avatar
Davis King committed
964
965
966
967
968
969
// ----------------------------------------------------------------------------------------

    class affine_
    {
    public:
        affine_(
970
971
972
        ) : mode(FC_MODE)
        {
        }
Davis King's avatar
Davis King committed
973

974
        affine_(
975
976
977
978
979
980
981
982
983
984
            layer_mode mode_
        ) : mode(mode_)
        {
        }

        template <
            layer_mode bnmode
            >
        affine_(
            const bn_<bnmode>& item
985
986
987
988
        )
        {
            gamma = item.gamma;
            beta = item.beta;
989
            mode = bnmode;
990
991
992
993
994
995
996
997
998
999

            params.copy_size(item.params);

            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            
            resizable_tensor temp(item.params);
            auto sg = gamma(temp,0);
            auto sb = beta(temp,gamma.size());

1000
            g = pointwise_multiply(mat(sg), 1.0f/sqrt(mat(item.running_variances)+tt::BATCH_NORM_EPS));
1001
1002
1003
1004
1005
            b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means));
        }

        layer_mode get_mode() const { return mode; }

Davis King's avatar
Davis King committed
1006
1007
1008
        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
            if (mode == FC_MODE)
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
Davis King's avatar
Davis King committed
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
1032
1033
1034
1035
            if (mode == FC_MODE)
                tt::affine_transform(output, input, g, b);
            else
                tt::affine_transform_conv(output, input, g, b);
Davis King's avatar
Davis King committed
1036
1037
1038
1039
1040
        } 

        void backward_inplace(
            const tensor& gradient_input, 
            tensor& data_grad, 
Davis King's avatar
Davis King committed
1041
            tensor& /*params_grad*/
Davis King's avatar
Davis King committed
1042
1043
1044
1045
1046
1047
        )
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());

            // We are computing the gradient of dot(gradient_input, computed_output*g + b)
1048
1049
1050
1051
1052
1053
1054
1055
            if (mode == FC_MODE)
            {
                tt::multiply(data_grad, gradient_input, g);
            }
            else
            {
                tt::multiply_conv(data_grad, gradient_input, g);
            }
Davis King's avatar
Davis King committed
1056
1057
        }

1058
1059
        const tensor& get_layer_params() const { return empty_params; }
        tensor& get_layer_params() { return empty_params; }
Davis King's avatar
Davis King committed
1060
1061
1062
1063
1064
1065
1066

        friend void serialize(const affine_& item, std::ostream& out)
        {
            serialize("affine_", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
1067
            serialize((int)item.mode, out);
Davis King's avatar
Davis King committed
1068
1069
1070
1071
1072
1073
        }

        friend void deserialize(affine_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
            if (version == "bn_con")
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
                bn_<CONV_MODE> temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }
            else if (version == "bn_fc")
1085
1086
1087
1088
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
1089
                bn_<FC_MODE> temp;
1090
1091
1092
1093
1094
                deserialize(temp, sin);
                item = temp;
                return;
            }

Davis King's avatar
Davis King committed
1095
            if (version != "affine_")
1096
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");
Davis King's avatar
Davis King committed
1097
1098
1099
            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
1100
1101
1102
            int mode;
            deserialize(mode, in);
            item.mode = (layer_mode)mode;
Davis King's avatar
Davis King committed
1103
1104
        }

1105
1106
1107
1108
1109
1110
        friend std::ostream& operator<<(std::ostream& out, const affine_& )
        {
            out << "affine";
            return out;
        }

Davis King's avatar
Davis King committed
1111
    private:
1112
        resizable_tensor params, empty_params; 
Davis King's avatar
Davis King committed
1113
        alias_tensor gamma, beta;
1114
        layer_mode mode;
Davis King's avatar
Davis King committed
1115
1116
1117
    };

    template <typename SUBNET>
1118
    using affine = add_layer<affine_, SUBNET>;
Davis King's avatar
Davis King committed
1119

Davis King's avatar
Davis King committed
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class add_prev_
    {
    public:
        add_prev_() 
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            output.copy_size(sub.get_output());
            tt::add(output, sub.get_output(), layer<tag>(sub).get_output());
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            // The gradient just flows backwards to the two layers that forward() added
            // together.
            tt::add(sub.get_gradient_input(), sub.get_gradient_input(), gradient_input);
            tt::add(layer<tag>(sub).get_gradient_input(), layer<tag>(sub).get_gradient_input(), gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const add_prev_& , std::ostream& out)
        {
            serialize("add_prev_", out);
        }

        friend void deserialize(add_prev_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "add_prev_")
1166
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_.");
Davis King's avatar
Davis King committed
1167
1168
        }

1169
1170
1171
1172
1173
1174
1175
        friend std::ostream& operator<<(std::ostream& out, const add_prev_& item)
        {
            out << "add_prev";
            return out;
        }


Davis King's avatar
Davis King committed
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
    private:
        resizable_tensor params;
    };

    // add_prev_<tag> stacked on top of SUBNET.
    template <template<typename> class tag, typename SUBNET>
    using add_prev = add_layer<add_prev_<tag>, SUBNET>;

    // Layer-detail types bound to the numbered tags tag1..tag10.
    using add_prev1_  = add_prev_<tag1>;
    using add_prev2_  = add_prev_<tag2>;
    using add_prev3_  = add_prev_<tag3>;
    using add_prev4_  = add_prev_<tag4>;
    using add_prev5_  = add_prev_<tag5>;
    using add_prev6_  = add_prev_<tag6>;
    using add_prev7_  = add_prev_<tag7>;
    using add_prev8_  = add_prev_<tag8>;
    using add_prev9_  = add_prev_<tag9>;
    using add_prev10_ = add_prev_<tag10>;

    // Matching network-building aliases, one per numbered tag.
    template <typename SUBNET> using add_prev1  = add_layer<add_prev1_,  SUBNET>;
    template <typename SUBNET> using add_prev2  = add_layer<add_prev2_,  SUBNET>;
    template <typename SUBNET> using add_prev3  = add_layer<add_prev3_,  SUBNET>;
    template <typename SUBNET> using add_prev4  = add_layer<add_prev4_,  SUBNET>;
    template <typename SUBNET> using add_prev5  = add_layer<add_prev5_,  SUBNET>;
    template <typename SUBNET> using add_prev6  = add_layer<add_prev6_,  SUBNET>;
    template <typename SUBNET> using add_prev7  = add_layer<add_prev7_,  SUBNET>;
    template <typename SUBNET> using add_prev8  = add_layer<add_prev8_,  SUBNET>;
    template <typename SUBNET> using add_prev9  = add_layer<add_prev9_,  SUBNET>;
    template <typename SUBNET> using add_prev10 = add_layer<add_prev10_, SUBNET>;

1208
1209
1210
1211
1212
1213
1214
1215
1216
// ----------------------------------------------------------------------------------------

    class relu_
    {
    public:
        relu_() 
        {
        }

Davis King's avatar
Davis King committed
1217
        template <typename SUBNET>
Davis King's avatar
Davis King committed
1218
        void setup (const SUBNET& /*sub*/)
1219
1220
1221
        {
        }

1222
        void forward_inplace(const tensor& input, tensor& output)
1223
        {
1224
            tt::relu(output, input);
1225
1226
        } 

1227
1228
1229
1230
        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
1231
            tensor& 
1232
        )
1233
        {
1234
            tt::relu_gradient(data_grad, computed_output, gradient_input);
1235
1236
1237
1238
1239
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

Davis King's avatar
Davis King committed
1240
        friend void serialize(const relu_& , std::ostream& out)
1241
        {
1242
            serialize("relu_", out);
1243
1244
        }

Davis King's avatar
Davis King committed
1245
        friend void deserialize(relu_& , std::istream& in)
1246
        {
1247
1248
1249
            std::string version;
            deserialize(version, in);
            if (version != "relu_")
1250
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
1251
1252
        }

1253
1254
1255
1256
1257
1258
1259
        friend std::ostream& operator<<(std::ostream& out, const relu_& )
        {
            out << "relu";
            return out;
        }


1260
1261
1262
1263
1264
1265
1266
1267
    private:
        resizable_tensor params;
    };


    // relu_ stacked on top of SUBNET.
    template <typename SUBNET> using relu = add_layer<relu_, SUBNET>;

Davis King's avatar
Davis King committed
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
// ----------------------------------------------------------------------------------------

    class prelu_
    {
    public:
        explicit prelu_(
            float initial_param_value_ = 0.25
        ) : initial_param_value(initial_param_value_)
        {
        }

1279
1280
1281
        float get_initial_param_value (
        ) const { return initial_param_value; }

Davis King's avatar
Davis King committed
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
            params.set_size(1);
            params = initial_param_value;
        }

        template <typename SUBNET>
        void forward(
            const SUBNET& sub, 
            resizable_tensor& data_output
        )
        {
            data_output.copy_size(sub.get_output());
            tt::prelu(data_output, sub.get_output(), params);
        }

        template <typename SUBNET>
        void backward(
            const tensor& gradient_input, 
            SUBNET& sub, 
            tensor& params_grad
        )
        {
            tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(), 
                gradient_input, params, params_grad);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const prelu_& item, std::ostream& out)
        {
            serialize("prelu_", out);
            serialize(item.params, out);
            serialize(item.initial_param_value, out);
        }

        friend void deserialize(prelu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "prelu_")
1325
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_.");
Davis King's avatar
Davis King committed
1326
1327
1328
1329
            deserialize(item.params, in);
            deserialize(item.initial_param_value, in);
        }

1330
1331
1332
1333
1334
1335
1336
1337
        friend std::ostream& operator<<(std::ostream& out, const prelu_& item)
        {
            out << "prelu\t ("
                << "initial_param_value="<<item.initial_param_value
                << ")";
            return out;
        }

Davis King's avatar
Davis King committed
1338
1339
1340
1341
1342
1343
1344
1345
    private:
        resizable_tensor params;
        float initial_param_value;
    };

    // prelu_ stacked on top of SUBNET.
    template <typename SUBNET> using prelu = add_layer<prelu_, SUBNET>;

1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
// ----------------------------------------------------------------------------------------

    class sig_
    {
    public:
        sig_() 
        {
        }

        template <typename SUBNET>
Davis King's avatar
Davis King committed
1356
        void setup (const SUBNET& /*sub*/)
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::sigmoid(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::sigmoid_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const sig_& , std::ostream& out)
        {
            serialize("sig_", out);
        }

        friend void deserialize(sig_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "sig_")
1388
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_.");
1389
        }
1390

1391
1392
1393
1394
1395
1396
1397
        friend std::ostream& operator<<(std::ostream& out, const sig_& )
        {
            out << "sig";
            return out;
        }


1398
    private:
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
        resizable_tensor params;
    };


    // sig_ stacked on top of SUBNET.
    template <typename SUBNET> using sig = add_layer<sig_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class htan_
    {
    public:
        htan_() 
        {
        }

        template <typename SUBNET>
Davis King's avatar
Davis King committed
1416
        void setup (const SUBNET& /*sub*/)
1417
1418
1419
1420
1421
1422
1423
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::tanh(output, input);
        } 
1424

1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::tanh_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const htan_& , std::ostream& out)
        {
            serialize("htan_", out);
        }

        friend void deserialize(htan_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "htan_")
1448
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_.");
1449
1450
        }

1451
1452
1453
1454
1455
1456
1457
        friend std::ostream& operator<<(std::ostream& out, const htan_& )
        {
            out << "htan";
            return out;
        }


1458
    private:
1459
1460
1461
        resizable_tensor params;
    };

1462

Davis King's avatar
Davis King committed
1463
    template <typename SUBNET>
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
    using htan = add_layer<htan_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class softmax_
    {
    public:
        softmax_() 
        {
        }

        template <typename SUBNET>
Davis King's avatar
Davis King committed
1476
        void setup (const SUBNET& /*sub*/)
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::softmax(output, input);
        } 

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& 
        )
        {
            tt::softmax_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const softmax_& , std::ostream& out)
        {
            serialize("softmax_", out);
        }

        friend void deserialize(softmax_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "softmax_")
1508
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
1509
1510
        }

1511
1512
1513
1514
1515
1516
        friend std::ostream& operator<<(std::ostream& out, const softmax_& )
        {
            out << "softmax";
            return out;
        }

1517
1518
1519
1520
1521
1522
    private:
        resizable_tensor params;
    };

    // softmax_ stacked on top of SUBNET.
    template <typename SUBNET> using softmax = add_layer<softmax_, SUBNET>;
1523
1524
1525
1526
1527

// ----------------------------------------------------------------------------------------

}

1528
#endif // DLIB_DNn_LAYERS_H_
1529
1530