# residual_layer.py
import torch.nn as nn
from modules.initializers import GlorotOrthogonal

class ResidualLayer(nn.Module):
    """Residual block: ``inputs + f(inputs)`` with ``f`` two same-width linear layers.

    ``f`` is ``dense_1`` followed by ``dense_2``. When ``activation`` is given it
    is applied after *each* linear layer, including the second one, before the
    skip connection is added.

    Args:
        units: Input and output width of both ``nn.Linear`` layers. The last
            dimension of ``inputs`` must equal ``units`` so the residual sum is
            well-formed.
        activation: Optional callable applied to the output of each linear
            layer. ``None`` means no non-linearity.
    """

    def __init__(self, units, activation=None):
        super().__init__()

        self.activation = activation
        self.dense_1 = nn.Linear(units, units)
        self.dense_2 = nn.Linear(units, units)

        self.reset_params()

    def reset_params(self):
        """Re-initialize both layers: weights via ``GlorotOrthogonal``, biases to zero.

        NOTE(review): ``GlorotOrthogonal`` comes from ``modules.initializers``
        and is presumably an in-place initializer (its return value is
        ignored) -- confirm there.
        """
        GlorotOrthogonal(self.dense_1.weight)
        nn.init.zeros_(self.dense_1.bias)
        GlorotOrthogonal(self.dense_2.weight)
        nn.init.zeros_(self.dense_2.bias)

    def _activate(self, x):
        """Apply ``self.activation`` to ``x``, or return ``x`` unchanged if it is None."""
        if self.activation is None:
            return x
        return self.activation(x)

    def forward(self, inputs):
        """Return ``inputs + act(dense_2(act(dense_1(inputs))))``.

        ``act`` is ``self.activation``, skipped entirely when it is None.
        """
        x = self._activate(self.dense_1(inputs))
        x = self._activate(self.dense_2(x))
        # Skip connection: the block learns a correction added to its input.
        return inputs + x