# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch import nn


class Highway(torch.nn.Module):
    """
    A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
    Adopted from the AllenNLP implementation.

    Each layer computes ``y = g * x + (1 - g) * relu(H(x))``.
    A single ``nn.Linear(input_dim, 2 * input_dim)`` produces two things:

    - the candidate transform ``H(x)``, from the first half of its output;
    - the gate logits, from the second half of its output.

    The gate is ``g = sigmoid(gate_logits)``, so ``g`` weights how much of the
    input is carried through unchanged. Layers are applied in sequence, and the
    output has the same shape as the input.

    Args:
        input_dim: size of the last dimension of the input (and output).
        num_layers: number of stacked highway layers. With ``0`` the module
            returns its input unchanged.
    """

    def __init__(self, input_dim: int, num_layers: int = 1):
        super().__init__()
        self.input_dim = input_dim
        # One Linear per layer. Its 2 * input_dim outputs are split in
        # ``forward`` into the candidate projection and the gate logits.
        self.layers = nn.ModuleList(
            [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]
        )
        self.activation = nn.ReLU()

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialize every layer's weights and biases.

        Weights are Xavier-normal. The projection half of the bias is set
        to 0, and the gate half of the bias is set to 1.
        """
        for layer in self.layers:
            # As per comment in AllenNLP:
            # We should bias the highway layer to just carry its input forward.  We do that by
            # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
            # be high, so we will carry the input forward.  The bias on `B(x)` is the second half
            # of the bias vector in each Linear layer.
            nn.init.constant_(layer.bias[self.input_dim:], 1)

            nn.init.constant_(layer.bias[:self.input_dim], 0)
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all highway layers in sequence.

        Args:
            x: tensor whose last dimension is ``input_dim``. Leading
                dimensions are passed through by ``nn.Linear``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        for layer in self.layers:
            projection = layer(x)
            # First half -> candidate transform, second half -> gate logits.
            proj_x, gate = projection.chunk(2, dim=-1)
            proj_x = self.activation(proj_x)
            gate = torch.sigmoid(gate)
            # ``1 - gate`` broadcasts the Python scalar in gate's dtype/device.
            # This avoids allocating a constant tensor on every iteration.
            x = gate * x + (1 - gate) * proj_x
        return x