"examples/vscode:/vscode.git/clone" did not exist on "0ff70821c9b0b991197fa7f3264bf9dd78b8d4b3"
linear.py 2.24 KB
Newer Older
1
2
3
4
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
5
from torch.nn import Parameter
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


class Linear(nn.Module):
    """Plain (non-parallel) linear layer with an optional deferred bias add.

    The layer is defined as Y = XA + b, where A has shape
    (input_size, output_size). The ``weight`` parameter stores A transposed,
    i.e. with shape (output_size, input_size), which is the layout
    ``torch.nn.functional.linear`` expects.

    The weight is initialized with ``torch.nn.init.normal_`` using its default
    arguments; the bias, when present, is always initialized to zero.

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, create a learnable bias; otherwise ``self.bias`` is
              registered as None.
        skip_bias_add: This was added to enable performance optimizations
                       where bias can be fused with other elementwise
                       operations. We skip adding bias but instead return it.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
    ) -> None:
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        # Stored as (output_size, input_size) to match F.linear's convention.
        self.weight = Parameter(torch.empty(self.output_size, self.input_size))
        init.normal_(self.weight)

        if bias:
            # Always initialize bias to zero.
            self.bias = Parameter(torch.zeros(self.output_size))
        else:
            # Registering None keeps `self.bias` a valid attribute so that
            # forward() and __repr__() can read it unconditionally.
            self.register_parameter("bias", None)

    def forward(self, input_: torch.Tensor):
        """Apply the linear transform to ``input_``.

        Returns:
            The output tensor when ``skip_bias_add`` is false (bias, if any,
            already added). Otherwise a tuple ``(output, bias)`` where
            ``output`` has no bias applied and ``bias`` is ``self.bias``
            (None when the layer was built with ``bias=False``), left for the
            caller to add.
        """
        if self.skip_bias_add:
            # Matrix multiply only; the bias is handed back to the caller.
            return F.linear(input_, self.weight, None), self.bias

        # Matrix multiply with the bias (possibly None) applied in place.
        return F.linear(input_, self.weight, self.bias)

    def __repr__(self) -> str:
        """Return a one-line summary of the layer's configuration."""
        return (
            f"Linear(in_features={self.input_size}, "
            f"out_features={self.output_size}, "
            f"bias={self.bias is not None}, "
            f"skip_bias_add={self.skip_bias_add})"
        )