# Copyright (c) 2022, Tri Dao.

from typing import Optional
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

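# Optional fused kernel for the parallel-residual block; the name is left as None
# when the extension cannot be imported.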
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

# Residual blocks that pair a sequence mixer (MHA by default) with an MLP.

class Block(nn.Module):
    """A sequence mixer followed by an optional MLP, each wrapped with dropout, a residual
    add and a norm layer, in either prenorm or postnorm arrangement.
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
                 residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        Args:
            dim: passed as the single positional argument to mixer_cls, mlp_cls and norm_cls.
            mixer_cls: factory for the mixer; defaults to MHA with num_heads = dim // 64.
            mlp_cls: factory for the MLP; defaults to Mlp with hidden_features = 4 * dim.
                If it produces an nn.Identity, the second dropout / drop-path / norm are not
                created and the MLP half of the block is skipped in forward.
            norm_cls: factory for the norm layers.
            dropout_cls: factory for the residual dropouts, called with the dropout probability.
            prenorm: selects the prenorm (True) or postnorm (False) structure described above.
            resid_dropout1, resid_dropout2: probabilities given to dropout_cls.
            drop_path1, drop_path2: probabilities given to StochasticDepth (mode='row').
            fused_dropout_add_ln: use the fused dropout_add_layer_norm op instead of separate
                dropout / add / norm calls. Requires that op to be importable, norm1 to be an
                nn.LayerNorm and dropout1 to be an nn.Dropout.
            return_residual: whether each of the sub-layers (mixer and mlp) will return the
                residual. This is for performance reason: for post-norm architecture, returning
                the input allows us to fuse the backward of nn.Linear with the residual
                connection. Only read on the postnorm path.
            residual_in_fp32: keep the returned residual in float32; only valid with prenorm=True.
            sequence_parallel: set ``_sequence_parallel = True`` on the norm parameters.
            mark_shared_params: set ``_shared_params = True`` on the norm parameters.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def _drop_path_rowscale(self, drop_path, like: Tensor) -> Optional[Tensor]:
        """Return the ``rowscale`` argument for the fused op, or None when drop-path is inactive.

        The scale is obtained by running ``drop_path`` over a tensor of ones shaped like
        ``like`` without its last dimension, on the same device and with the same dtype.
        """
        if drop_path.p == 0 or not self.training:
            return None
        ones = torch.ones(like.shape[:-1], device=like.device, dtype=like.dtype)
        return drop_path(ones)

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
                Only used when prenorm=True; the postnorm path does not read it.
            mixer_kwargs: optional keyword arguments forwarded to the mixer. The mapping
                passed in is never modified.

        Returns:
            ``(hidden_states, residual)`` when prenorm=True, otherwise ``hidden_states`` only.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                rowscale1 = self._drop_path_rowscale(self.drop_path1, hidden_states)
                hidden_states, residual = dropout_add_layer_norm(
                    hidden_states, residual, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                # Build a new dict instead of writing into the caller's mapping, which may be
                # reused for other layers.
                mixer_kwargs = {**mixer_kwargs, 'mixer_subset': mixer_subset}
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Keep the residual aligned with the positions the mixer kept.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    rowscale2 = self._drop_path_rowscale(self.drop_path2, hidden_states)
                    hidden_states, residual = dropout_add_layer_norm(
                        hidden_states, residual, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                            + hidden_states).to(dtype=self.norm1.weight.dtype))
            else:
                rowscale1 = self._drop_path_rowscale(self.drop_path1, mixer_out)
                hidden_states = dropout_add_layer_norm(
                    mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=rowscale1, prenorm=False
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                                + hidden_states).to(dtype=self.norm2.weight.dtype))
                else:
                    rowscale2 = self._drop_path_rowscale(self.drop_path2, mlp_out)
                    hidden_states = dropout_add_layer_norm(
                        mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=rowscale2, prenorm=False
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
                 tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
                 sequence_parallel=False, mark_shared_params=False):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: passed as the single positional argument to mixer_cls, mlp_cls and norm_cls.
            mixer_cls: factory for the mixer; defaults to MHA with num_heads = dim // 64.
            mlp_cls: factory for the MLP; defaults to Mlp with hidden_features = 4 * dim.
            norm_cls: factory for the norm layers.
            dropout_cls: factory for the dropouts, called with the dropout probability.
            resid_dropout1, resid_dropout2: probabilities given to dropout_cls for the mixer
                and MLP inputs respectively.
            tied_norm: if True, no second norm is created and the MLP consumes the same
                normalized tensor as the mixer.
            fused_dropout_add_ln: use the fused dropout_add_layer_norm_parallel_residual op.
                Requires that op to be importable, norm1 to be an nn.LayerNorm and dropout1 to
                be an nn.Dropout.
            residual_in_fp32: keep the returned residual in float32.
            sequence_parallel: set ``_sequence_parallel = True`` on the norm parameters.
            mark_shared_params: set ``_shared_params = True`` on the norm parameters.
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm_parallel_residual is not None, (
                'dropout_layer_norm is not installed'
            )
            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # norm2 only exists when the norms are not tied.
        norm_layers = [self.norm1]
        if hasattr(self, 'norm2'):
            norm_layers.append(self.norm2)
        for norm in norm_layers:
            for p in norm.parameters():
                # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on
                # their grads.
                if sequence_parallel:
                    p._sequence_parallel = True
                # Mark the norm parameters as "shared_params" so that we sync their values at
                # init.
                if mark_shared_params:
                    p._shared_params = True

    def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                residual: Optional[Tensor] = None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the residual accumulated so far; None for the very first block.
            mixer_kwargs: optional keyword arguments forwarded to the mixer.

        Returns:
            ``(hidden_states1, hidden_states2, residual)``: the mixer output, the MLP output and
            the updated residual.
        """
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = ((residual + dropped1 + dropped2)
                            if residual is not None else dropped1 + dropped2)
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                              if not self.tied_norm else hidden_states1)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
                              if not self.tied_norm else (None, None))
            # NOTE: only dropout1.p is handed to the fused op; dropout2.p is not used on this
            # path.
            hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
                hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
                weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
                prenorm=True, residual_in_fp32=self.residual_in_fp32
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual