# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reasons: for the post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
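
        Example (a minimal usage sketch, assuming the default MHA mixer and Mlp sub-layers; the
        256-dim model size and the input shape below are illustrative only):

            block = Block(256, prenorm=True)     # MHA with 256 // 64 = 4 heads, MLP ratio 4
            x = torch.randn(batch_size, seqlen, 256)
            hidden_states, residual = block(x)   # very first block: residual starts as None
            hidden_states, residual = block(hidden_states, residual)  # subsequent blocks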
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual)).
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful e.g. for ViT, where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
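                # Fused path: layer_norm_fn fuses the dropout, residual add, and LayerNorm/RMSNorm
                # of the unfused branch above into a single Triton kernel; with prenorm=True it
                # returns both the normalized hidden_states and the updated residual.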
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
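                # Fused post-norm path: with prenorm=False, layer_norm_fn returns only the
                # normalized result of dropout(mixer_out) + hidden_states, matching the unfused
                # branch above (drop_path is folded in via rowscale).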
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (the outputs of the MHA and the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
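
        Example (a minimal usage sketch with the default MHA / Mlp sub-layers; the 256-dim model
        size and the input shape below are illustrative only):

            block = ParallelBlock(256)
            x = torch.randn(batch_size, seqlen, 256)
            out1, out2, residual = block(x)                     # first block: MHA out, MLP out, residual
            out1, out2, residual = block(out1, out2, residual)  # subsequent blocks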
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the running residual stream from the previous block (None for the very first block).
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
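            # Fused path: a single layer_norm_fn call performs the dropout + add of both branches
            # into the residual and applies norm1 (and norm2 via weight1 / bias1 when the norms
            # are not tied), mirroring the unfused branch above.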
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm)
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2, = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual