# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                                    residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
        1.0, 0, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


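# The helper above only reshapes its inputs and calls the fused CUDA kernel. For reference, here
# is a rough pure-PyTorch sketch of the math it is expected to compute, based on the argument
# names and comments in this file (an illustration only, not a drop-in or bit-exact replacement;
# it ignores residual_in_fp32 / is_rms_norm, and the order in which rowscale/colscale and the
# dropout mask are applied inside the kernel is an assumption here, although these element-wise
# scalings commute mathematically).
def _reference_dropout_add_layer_norm(x0, x1, gamma, beta, rowscale, colscale, dropout_p,
                                      epsilon):
    """Unfused sketch of LayerNorm(dropout(x0 * rowscale * colscale) + x1).

    rowscale is expected to have shape x0.shape[:-1]; colscale has shape (hidden_size,).
    """
    x = x0 * rowscale.unsqueeze(-1) if rowscale is not None else x0
    x = x * colscale if colscale is not None else x
    x = torch.nn.functional.dropout(x, p=dropout_p)
    if x1 is not None:
        x = x + x1
    z = torch.nn.functional.layer_norm(x, (gamma.numel(),), gamma, beta, epsilon)
    return z, x  # the normalized output and the pre-norm residual

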
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + x1 was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
        dropout_p, 1.0, 0, has_residual, is_rms_norm
    )
    # dx1mat is None if not has_residual
    if colscale is None:
        return dx0mat, dx1mat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dx1mat, dgamma, dbeta, dcolscale


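# Added commentary (a reading of how the arguments below are threaded through, not upstream
# documentation): the "subset" variants replace the per-row `rowscale` tensor with a scalar
# `rowscale_const`, and take integer index vectors `x0_subset` / `out_subset` together with row
# counts (`out_numrows` in the forward, `x0_numrows` in the backward), so that x0 and the
# normalized output may cover only a subset of the rows of the residual stream.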
def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
                                           dropout_p, epsilon, rowscale_const, out_numrows,
                                           residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                            x0_subset, out_subset, dropout_p, rowscale_const,
                                            x0_numrows, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + x1 was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
    )
    # dx1mat is None if not has_residual
    if colscale is None:
        return dx0mat, dx1mat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dx1mat, dgamma, dbeta, dcolscale


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
                              colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
            ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
                None, None, None, None)


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                rowscale_const, out_numrows, residual_in_fp32=False,
                prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = x1 is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return (zmat.view(z_shape) if not prenorm
                    else (zmat.view(z_shape), xmat.view(x_shape)))
        else:
            z = zmat.view(z_shape)
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
                None, None, None, None, None, None, None)


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


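# Illustrative usage (a sketch added for documentation, not part of the upstream API; the helper
# name below is made up). With dropout_p=0 and no residual, `layer_norm` above is a fused
# equivalent of torch.nn.functional.layer_norm. Running it requires a CUDA device and the
# compiled `dropout_layer_norm` extension.
def _layer_norm_example(hidden_size=1024, eps=1e-5):
    x = torch.randn(8, 512, hidden_size, device='cuda', dtype=torch.float16)
    weight = torch.ones(hidden_size, device='cuda', dtype=torch.float16)
    bias = torch.zeros(hidden_size, device='cuda', dtype=torch.float16)
    out = layer_norm(x, weight, bias, eps)
    ref = torch.nn.functional.layer_norm(x.float(), (hidden_size,), weight.float(),
                                         bias.float(), eps)
    return (out.float() - ref).abs().max()  # should be small (fp16 rounding)

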
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                           prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        False, return_dropout_mask
    )


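# Illustrative usage (a sketch added for documentation, not part of the upstream API; the helper
# name below is made up): a pre-norm residual connection, where prenorm=True makes the function
# return both the normalized output and the updated residual stream x = dropout(x0) + x1.
# Requires a CUDA device and the compiled `dropout_layer_norm` extension.
def _dropout_add_layer_norm_prenorm_example(hidden_size=1024, dropout_p=0.1, eps=1e-5):
    x0 = torch.randn(4, 128, hidden_size, device='cuda', dtype=torch.float16)  # sublayer output
    residual = torch.randn_like(x0)                                            # residual stream
    weight = torch.ones(hidden_size, device='cuda', dtype=torch.float16)
    bias = torch.zeros(hidden_size, device='cuda', dtype=torch.float16)
    out, new_residual = dropout_add_layer_norm(
        x0, residual, weight, bias, dropout_p, eps, prenorm=True
    )
    return out, new_residual

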
def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
                                  return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
    )


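# Note (added commentary, based on comparing the two wrappers above): with the defaults
# x0_subset=None, out_subset=None, rowscale_const=1.0 and out_numrows=0, the subset variant
# passes exactly the same arguments to the fused kernel as dropout_add_layer_norm with
# rowscale=None, so the two calls are expected to be interchangeable in that degenerate case.

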
class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, x1=None):
        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
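

# Illustrative usage of the module (a sketch added for documentation, not part of the upstream
# API; the helper name below is made up). Requires a CUDA device and the compiled
# `dropout_layer_norm` extension.
def _dropout_add_layer_norm_module_example(hidden_size=1024):
    norm = DropoutAddLayerNorm(hidden_size, prenorm=True, p=0.1,
                               device='cuda', dtype=torch.float16)
    x0 = torch.randn(4, 128, hidden_size, device='cuda', dtype=torch.float16)
    residual = torch.randn_like(x0)
    # With prenorm=True the module also returns the updated residual stream.
    out, new_residual = norm(x0, residual)
    return out, new_residual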