// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_LAYERS_H_
#define DLIB_DNn_LAYERS_H_

#include "layers_abstract.h"
#include "tensor.h"
#include "core.h"
#include <iostream>
#include <string>
#include "../rand.h"
#include "../string.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    class con_
    {
    public:
        con_()
        {}

Davis King's avatar
Davis King committed
26
27
        template <typename SUBNET>
        void setup (const SUBNET& sub)
28
29
30
31
        {
            // TODO
        }

Davis King's avatar
Davis King committed
32
33
        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
34
35
36
37
        {
            // TODO
        } 

Davis King's avatar
Davis King committed
38
        template <typename SUBNET>
39
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
40
41
42
43
44
45
46
47
48
49
50
51
        {
            // TODO
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

    private:

        resizable_tensor params;
    };

Davis King's avatar
Davis King committed
52
53
    template <typename SUBNET>
    using con = add_layer<con_, SUBNET>;
54
55
56
57
58
59

// ----------------------------------------------------------------------------------------

    class fc_
    {
    public:
Davis King's avatar
Davis King committed
60
        fc_() : num_outputs(1), num_inputs(0)
61
62
63
        {
        }

64
65
        explicit fc_(
            unsigned long num_outputs_
Davis King's avatar
Davis King committed
66
        ) : num_outputs(num_outputs_), num_inputs(0)
67
68
69
70
71
72
        {
        }

        unsigned long get_num_outputs (
        ) const { return num_outputs; }

Davis King's avatar
Davis King committed
73
74
        template <typename SUBNET>
        void setup (const SUBNET& sub)
75
76
77
78
        {
            num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
            params.set_size(num_inputs, num_outputs);

79
            dlib::rand rnd("fc_"+cast_to_string(num_outputs));
80
81
82
            randomize_parameters(params, num_inputs+num_outputs, rnd);
        }

Davis King's avatar
Davis King committed
83
84
        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
85
        {
86
            output.set_size(sub.get_output().num_samples(), num_outputs);
87
88
89
90

            output = mat(sub.get_output())*mat(params);
        } 

Davis King's avatar
Davis King committed
91
        template <typename SUBNET>
92
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
93
94
        {
            // compute the gradient of the parameters.  
95
            params_grad = trans(mat(sub.get_output()))*mat(gradient_input);
96
97
98
99
100
101
102
103

            // compute the gradient for the data
            sub.get_gradient_input() += mat(gradient_input)*trans(mat(params));
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
        friend void serialize(const fc_& item, std::ostream& out)
        {
            serialize("fc_", out);
            serialize(item.num_outputs, out);
            serialize(item.num_inputs, out);
            serialize(item.params, out);
        }

        friend void deserialize(fc_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "fc_")
                throw serialization_error("Unexpected version found while deserializing dlib::fc_.");
            deserialize(item.num_outputs, in);
            deserialize(item.num_inputs, in);
            deserialize(item.params, in);
        }

123
124
125
126
127
128
129
130
    private:

        unsigned long num_outputs;
        unsigned long num_inputs;
        resizable_tensor params;
    };


Davis King's avatar
Davis King committed
131
132
    template <typename SUBNET>
    using fc = add_layer<fc_, SUBNET>;
133
134
135
136
137
138
139
140
141
142

// ----------------------------------------------------------------------------------------

    class relu_
    {
    public:
        relu_() 
        {
        }

Davis King's avatar
Davis King committed
143
144
        template <typename SUBNET>
        void setup (const SUBNET& sub)
145
146
147
        {
        }

148
        void forward_inplace(const tensor& input, tensor& output)
149
        {
150
            output = lowerbound(mat(input), 0);
151
152
        } 

153
154
155
156
157
158
        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input, 
            tensor& data_grad, 
            tensor& params_grad
        )
159
160
        {
            const float* grad = gradient_input.host();
161
162
163
            const float* in = computed_output.host();
            float* out = data_grad.host();
            for (unsigned long i = 0; i < computed_output.size(); ++i)
164
165
            {
                if (in[i] > 0)
166
167
168
                    out[i] = grad[i];
                else
                    out[i] = 0;
169
170
171
172
173
174
175
            }

        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

Davis King's avatar
Davis King committed
176
        friend void serialize(const relu_& , std::ostream& out)
177
        {
178
            serialize("relu_", out);
179
180
        }

Davis King's avatar
Davis King committed
181
        friend void deserialize(relu_& , std::istream& in)
182
        {
183
184
185
186
            std::string version;
            deserialize(version, in);
            if (version != "relu_")
                throw serialization_error("Unexpected version found while deserializing dlib::relu_.");
187
188
189
190
191
192
193
194
        }


    private:

        resizable_tensor params;
    };

195

Davis King's avatar
Davis King committed
196
    template <typename SUBNET>
197
    using relu = add_layer<relu_, SUBNET>;
198
199
200
201
202

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_DNn_LAYERS_H_