block.py 15.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (c) 2022, Tri Dao.

from typing import Optional
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

# Optional fused dropout + add + norm kernels. Each name falls back to None when the
# extension cannot be imported; Block / ParallelBlock check for None before using them.
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

try:
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
    # Both names need their own fallback value: `RMSNorm, dropout_add_rms_norm = None` would try
    # to unpack None and raise TypeError, turning a missing optional extension into an
    # import-time crash.
    RMSNorm, dropout_add_rms_norm = None, None

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
    dropout_add_rms_norm_parallel_residual = None

36
37
38
39

class Block(nn.Module):

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
                 residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
        """
        Transformer block: a token mixer (attention by default) followed by an MLP, each
        sub-layer wrapped with dropout, stochastic depth (drop path), a residual add and a
        normalization layer.

        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        Args:
            dim: hidden dimension; passed to mixer_cls, mlp_cls and norm_cls.
            mixer_cls: callable taking ``dim`` and returning the mixer module. Defaults to MHA
                with ``dim // 64`` heads.
            mlp_cls: callable taking ``dim`` and returning the MLP module. Defaults to Mlp with
                ``4 * dim`` hidden features. If it returns nn.Identity, the second sub-layer
                (dropout2 / drop_path2 / norm2) is not created and is skipped in forward.
            norm_cls: callable taking ``dim`` and returning the normalization module.
            dropout_cls: callable taking a probability and returning the dropout module.
            prenorm: choose between the prenorm and postnorm layouts described above.
            resid_dropout1, resid_dropout2: dropout probabilities applied to the mixer / MLP
                outputs before the residual add.
            drop_path1, drop_path2: stochastic-depth probabilities (mode='row') for the mixer /
                MLP branches.
            fused_dropout_add_ln: use the fused dropout + add + norm kernels (requires the
                dropout_layer_norm extension).
            return_residual: whether each of the sub-layers (mixer and mlp) will return the
                residual. This is for performance reason: for post-norm architecture, returning
                the input allows us to fuse the backward of nn.Linear with the residual
                connection. Only consulted when prenorm=False.
            residual_in_fp32: keep the residual stream in float32. Only valid with prenorm=True.
            sequence_parallel: tag the norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag the norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
            assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed'
            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
                    and isinstance(self.dropout1, nn.Dropout))

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer. Only used when prenorm=True.
            mixer_kwargs: optional dict of extra keyword arguments forwarded to the mixer. The
                caller's dict is never modified.

        Returns:
            prenorm=True: ``(hidden_states, residual)``.
            prenorm=False: ``hidden_states``.
        """
        if mixer_kwargs is None:
            mixer_kwargs = {}
        # Resolve the fused kernel only when it will be used. RMSNorm is None when the rms_norm
        # extension could not be imported, and isinstance(x, None) raises TypeError, so an
        # unconditional lookup would also break the plain (non-fused) path.
        fused_add_norm_fn = None
        if self.fused_dropout_add_ln:
            fused_add_norm_fn = (dropout_add_rms_norm
                                 if RMSNorm is not None and isinstance(self.norm1, RMSNorm)
                                 else dropout_add_layer_norm)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                # Drop path is passed to the fused kernel as a per-row scale obtained by applying
                # drop_path1 to a tensor of ones; None when drop path is inactive.
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(torch.ones(
                        hidden_states.shape[:-1], device=hidden_states.device,
                        dtype=hidden_states.dtype)
                    )
                hidden_states, residual = fused_add_norm_fn(
                    hidden_states, residual, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                )
            if mixer_subset is not None:
                # Build a new dict instead of writing into the caller's mixer_kwargs, which may be
                # reused for other calls; otherwise a stale 'mixer_subset' would leak into them.
                mixer_kwargs = {**mixer_kwargs, 'mixer_subset': mixer_subset}
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Subset the residual along dim 1 so it lines up with the mixer output.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(torch.ones(
                            hidden_states.shape[:-1], device=hidden_states.device,
                            dtype=hidden_states.dtype)
                        )
                    hidden_states, residual = fused_add_norm_fn(
                        hidden_states, residual, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(hidden_states, **mixer_kwargs)
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                            + hidden_states).to(dtype=self.norm1.weight.dtype))
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(torch.ones(
                        mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
                    )
                hidden_states = fused_add_norm_fn(
                    mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=rowscale1, prenorm=False
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                                + hidden_states).to(dtype=self.norm2.weight.dtype))
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(torch.ones(
                            mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
                        )
                    hidden_states = fused_add_norm_fn(
                        mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=rowscale2, prenorm=False
                    )
            return hidden_states
Tri Dao's avatar
Tri Dao committed
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
                 tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
                 sequence_parallel=False, mark_shared_params=False):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (outputs of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: hidden dimension; passed to mixer_cls, mlp_cls and norm_cls.
            mixer_cls: callable taking ``dim`` and returning the mixer module. Defaults to MHA
                with ``dim // 64`` heads.
            mlp_cls: callable taking ``dim`` and returning the MLP module. Defaults to Mlp with
                ``4 * dim`` hidden features.
            norm_cls: callable taking ``dim`` and returning the normalization module.
            dropout_cls: callable taking a probability and returning the dropout module.
            resid_dropout1, resid_dropout2: dropout probabilities applied to the previous
                mixer / MLP outputs before they are added to the residual. On the fused path
                only resid_dropout1 is passed to the kernel.
            tied_norm: if True, norm2 is not created and the MLP branch reuses norm1's output.
            fused_dropout_add_ln: use the fused dropout + add + norm kernels (requires the
                dropout_layer_norm extension).
            residual_in_fp32: return the residual stream in float32.
            sequence_parallel: tag the norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag the norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
            assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
                    and isinstance(self.dropout1, nn.Dropout))

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                residual: Optional[Tensor] = None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer. If None (e.g. the very first
                block), only hidden_states1 is added to the residual.
            residual: the residual stream from the previous block, or None for the very first
                block.
            mixer_kwargs: optional dict of extra keyword arguments forwarded to the mixer.

        Returns:
            ``(hidden_states1, hidden_states2, residual)``: the mixer output, the MLP output and
            the updated residual.
        """
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = ((residual + dropped1 + dropped2)
                            if residual is not None else dropped1 + dropped2)
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                              if not self.tied_norm else hidden_states1)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            # Resolve the fused kernel only on this path: RMSNorm is None when the rms_norm
            # extension could not be imported, and isinstance(x, None) raises TypeError, so an
            # unconditional lookup would also break the plain (non-fused) path.
            fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                                 if RMSNorm is not None and isinstance(self.norm1, RMSNorm)
                                 else dropout_add_layer_norm_parallel_residual)
            weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
                              if not self.tied_norm else (None, None))
            # NOTE: only dropout1.p is passed to the fused kernel; dropout2 is not consulted on
            # this path.
            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
                hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
                weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
                prenorm=True, residual_in_fp32=self.residual_in_fp32
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual