# Copyright (c) 2022, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

try:
    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
    RMSNorm, dropout_add_rms_norm = None, None

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
    dropout_add_rms_norm_parallel_residual = None


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reasons: for the post-norm architecture, returning the input
        allows us to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
            assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual)).
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful e.g. for ViT, where we only care
                about the CLS token in the last layer.
        """
        fused_add_norm_fn = (
            dropout_add_rms_norm
            if RMSNorm and isinstance(self.norm1, RMSNorm)
            else dropout_add_layer_norm
        )
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
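                # drop_path is folded into the fused kernel as a per-row scale:
                # StochasticDepth applied to a ones tensor yields 0 for dropped rows and
                # 1 / (1 - p) for kept rows during training.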
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = fused_add_norm_fn(
                    hidden_states,
                    residual,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.dropout1.p if self.training else 0.0,
                    self.norm1.eps,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = fused_add_norm_fn(
                        hidden_states,
                        residual,
                        self.norm2.weight,
                        self.norm2.bias,
                        self.dropout2.p if self.training else 0.0,
                        self.norm2.eps,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = fused_add_norm_fn(
                    mixer_out,
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.dropout1.p if self.training else 0.0,
                    self.norm1.eps,
                    rowscale=rowscale1,
                    prenorm=False,
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = fused_add_norm_fn(
                        mlp_out,
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        self.dropout2.p if self.training else 0.0,
                        self.norm2.eps,
                        rowscale=rowscale2,
                        prenorm=False,
                    )
            return hidden_states


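# Usage sketch: how a stack of prenorm Blocks is typically chained. Each Block consumes and
# returns (hidden_states, residual); the final Dropout -> Add -> LayerNorm is the caller's
# responsibility. `_prenorm_block_example` and `ToyMixer` are illustrative stand-ins: the toy
# mixer replaces flash_attn.modules.mha.MHA so the sketch runs on CPU without the fused CUDA
# kernels, and all shapes / hyperparameters are arbitrary.
def _prenorm_block_example():
    class ToyMixer(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

        def forward(self, x):
            out, _ = self.attn(x, x, x, need_weights=False)
            return out

    dim, batch_size, seqlen = 256, 2, 128
    blocks = nn.ModuleList(
        [
            Block(dim, mixer_cls=ToyMixer, prenorm=True, resid_dropout1=0.1, resid_dropout2=0.1)
            for _ in range(4)
        ]
    )
    final_norm = nn.LayerNorm(dim)
    hidden_states = torch.randn(batch_size, seqlen, dim)
    residual = None  # the very first Block starts the residual stream
    for block in blocks:
        hidden_states, residual = block(hidden_states, residual)
    # One last Dropout -> Add -> LN, mirroring the structure inside each Block.
    return final_norm(F.dropout(hidden_states, p=0.1, training=False) + residual)

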
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (the outputs of the MHA and the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert (
                dropout_add_layer_norm_parallel_residual is not None
            ), "dropout_layer_norm is not installed"
            assert (
                dropout_add_rms_norm_parallel_residual is not None
            ), "dropout_layer_norm is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the residual from the previous block (None for the very first block).
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        fused_add_norm_fn = (
            dropout_add_rms_norm_parallel_residual
            if RMSNorm and isinstance(self.norm1, RMSNorm)
            else dropout_add_layer_norm_parallel_residual
        )
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
                hidden_states1,
                hidden_states2,
                residual,
                self.norm1.weight,
                self.norm1.bias,
                weight2,
                bias2,
                self.dropout1.p if self.training else 0.0,
                self.norm1.eps,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
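

# Usage sketch for ParallelBlock: the attention and MLP branches see the same residual stream,
# and each block hands (hidden_states1, hidden_states2, residual) to the next one. As above,
# `_parallel_block_example` and its ToyMixer are illustrative stand-ins (the toy mixer replaces
# flash_attn.modules.mha.MHA so the sketch runs on CPU without the fused CUDA kernels).
def _parallel_block_example():
    class ToyMixer(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

        def forward(self, x):
            out, _ = self.attn(x, x, x, need_weights=False)
            return out

    dim, batch_size, seqlen = 256, 2, 128
    blocks = nn.ModuleList(
        [
            ParallelBlock(dim, mixer_cls=ToyMixer, resid_dropout1=0.1, resid_dropout2=0.1)
            for _ in range(4)
        ]
    )
    final_norm = nn.LayerNorm(dim)
    hidden_states1 = torch.randn(batch_size, seqlen, dim)  # e.g. the embedding output
    hidden_states2, residual = None, None  # only populated from the second block onwards
    for block in blocks:
        hidden_states1, hidden_states2, residual = block(hidden_states1, hidden_states2, residual)
    # Final Dropout -> Add -> LN over both branches plus the residual stream.
    return final_norm(
        F.dropout(hidden_states1, p=0.1, training=False)
        + F.dropout(hidden_states2, p=0.1, training=False)
        + residual
    )


if __name__ == "__main__":
    # Quick CPU smoke test of the two sketches above.
    print(_prenorm_block_example().shape)
    print(_parallel_block_example().shape)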