"csrc/layer_norm/static_switch.h" did not exist on "de19de7ab13301e840a8bc483e77be6e424e7b32"
block.py 14.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright (c) 2022, Tri Dao.

from typing import Optional
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None


class Block(nn.Module):
    """Transformer block made of a mixer (attention by default) and an optional MLP.

    Each sub-layer is wrapped with dropout, stochastic depth, a residual add and a norm layer.
    Whether the norm comes before or after the sub-layer is selected by ``prenorm``.
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
                 drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
                 residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        Args:
            dim: model width; passed as the single positional argument to mixer_cls, mlp_cls
                and norm_cls.
            mixer_cls: factory called as mixer_cls(dim). Defaults to MHA with
                num_heads=dim // 64.
            mlp_cls: factory called as mlp_cls(dim). Defaults to Mlp with
                hidden_features=4 * dim. If it builds an nn.Identity, the second
                dropout / drop-path / norm are not created and the MLP half is skipped.
            norm_cls: factory called as norm_cls(dim) for each norm layer.
            dropout_cls: factory called with the dropout probability.
            prenorm: selects the prenorm (True) or postnorm (False) structure described above.
            resid_dropout1, resid_dropout2: probabilities handed to dropout_cls for the dropout
                applied to the input of the first / second residual add.
            drop_path1, drop_path2: probabilities handed to StochasticDepth (mode='row').
            fused_dropout_add_ln: use the fused dropout_add_layer_norm op. Requires the op to
                be importable, norm1 to be an nn.LayerNorm and dropout1 to be an nn.Dropout.
            return_residual: whether each of the sub-layers (mixer and mlp) will return the
                residual. This is for performance reason: for post-norm architecture, returning
                the input allows us to fuse the backward of nn.Linear with the residual
                connection. Only read in the postnorm path.
            residual_in_fp32: keep the returned residual in float32. Only compatible with
                prenorm=True.
            sequence_parallel: tag the norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag the norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        # An Identity MLP means "mixer-only block": no second dropout / drop-path / norm.
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def _drop_path_rowscale(self, drop_path, like):
        """Build the ``rowscale`` argument of the fused kernel for one stochastic-depth layer.

        Returns None when drop_path is inactive (p == 0 or the module is in eval mode).
        Otherwise returns drop_path applied to a tensor of ones that has the shape of ``like``
        without its last dimension, on the same device and with the same dtype.
        """
        if drop_path.p == 0 or not self.training:
            return None
        return drop_path(torch.ones(like.shape[:-1], device=like.device, dtype=like.dtype))

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
                Only the prenorm path reads this argument; the postnorm path ignores it.
            mixer_kwargs: optional dict of extra keyword arguments forwarded to the mixer.
                The dict is never modified by this method.

        Returns:
            prenorm=True: the tuple (hidden_states, residual).
            prenorm=False: hidden_states only.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = dropout_add_layer_norm(
                    hidden_states, residual, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=self._drop_path_rowscale(self.drop_path1, hidden_states),
                    prenorm=True, residual_in_fp32=self.residual_in_fp32
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                # Build a new dict instead of writing into the caller's: a kwargs dict shared
                # across several layers must not silently gain a 'mixer_subset' entry.
                mixer_kwargs = {**mixer_kwargs, 'mixer_subset': mixer_subset}
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                # Apply the same subset indexing to the residual that the mixer was asked to use.
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    hidden_states, residual = dropout_add_layer_norm(
                        hidden_states, residual, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=self._drop_path_rowscale(self.drop_path2, hidden_states),
                        prenorm=True, residual_in_fp32=self.residual_in_fp32
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                            + hidden_states).to(dtype=self.norm1.weight.dtype))
            else:
                hidden_states = dropout_add_layer_norm(
                    mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
                    rowscale=self._drop_path_rowscale(self.drop_path1, mixer_out),
                    prenorm=False
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                                + hidden_states).to(dtype=self.norm2.weight.dtype))
                else:
                    hidden_states = dropout_add_layer_norm(
                        mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
                        rowscale=self._drop_path_rowscale(self.drop_path2, mlp_out),
                        prenorm=False
                    )
            return hidden_states
Tri Dao's avatar
Tri Dao committed
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282


class ParallelBlock(nn.Module):
    """Block in which the attention (mixer) and the MLP run side by side on the same residual
    stream, in the manner of GPT-J, GPT-NeoX and PaLM.
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
                 tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
                 sequence_parallel=False, mark_shared_params=False):
        """
        Compared to a regular prenorm Transformer block, the order of operations is shifted.
        A regular block computes: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        This block computes: Dropout -> Add -> LN -> MHA / MLP, and hands back the two branch
        outputs together with the residual.
        The shifted order exists so that dropout, add and LayerNorm can be fused.
        Every block except the very first one must be given the residual.

        Args:
            dim: passed as the single positional argument to mixer_cls, mlp_cls and norm_cls.
            mixer_cls: factory for the mixer; defaults to MHA with num_heads=dim // 64.
            mlp_cls: factory for the MLP; defaults to Mlp with hidden_features=4 * dim.
            norm_cls: factory for the norm layer(s).
            dropout_cls: factory called with a dropout probability.
            resid_dropout1, resid_dropout2: probabilities for the dropout applied to the
                incoming mixer / MLP outputs.
            tied_norm: if True only norm1 is created and both branches read its output.
            fused_dropout_add_ln: must be False; the fused path is not implemented here.
            residual_in_fp32: cast the returned residual to float32.
            sequence_parallel: tag the norm parameters with ``_sequence_parallel = True``.
            mark_shared_params: tag the norm parameters with ``_shared_params = True``.
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
        self.residual_in_fp32 = residual_in_fp32

        build_mixer = partial(MHA, num_heads=dim // 64) if mixer_cls is None else mixer_cls
        build_mlp = partial(Mlp, hidden_features=4 * dim) if mlp_cls is None else mlp_cls

        # Sub-modules are created in a fixed order: mixer, dropout1, norm1, mlp, dropout2, norm2.
        self.mixer = build_mixer(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = build_mlp(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        # Kept for when the fused path gets implemented; the assert above rules it out for now.
        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # "_sequence_parallel": so that we run all-reduce on the norm grads.
        # "_shared_params": so that we sync the norm values at init.
        requested_tags = []
        if sequence_parallel:
            requested_tags.append('_sequence_parallel')
        if mark_shared_params:
            requested_tags.append('_shared_params')
        for tag in requested_tags:
            norm_layers = [self.norm1]
            if hasattr(self, 'norm2'):
                norm_layers.append(self.norm2)
            for norm_layer in norm_layers:
                for param in norm_layer.parameters():
                    setattr(param, tag, True)

    def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                residual: Optional[Tensor] = None, mixer_kwargs=None):
        r"""Run one parallel block.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, only hidden_states1
                is added to the residual).
            residual: the running residual, or None for the very first block.
            mixer_kwargs: optional extra keyword arguments forwarded to the mixer.

        Returns:
            The tuple (mixer output, mlp output, updated residual).
        """
        mixer_branch_in = self.dropout1(hidden_states1)
        stream = mixer_branch_in if residual is None else residual + mixer_branch_in
        # For the very 1st block, we only want 1 dropout, not two different dropouts
        if hidden_states2 is not None:
            stream = stream + self.dropout2(hidden_states2)
        residual = stream

        normed_for_mixer = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
        if self.tied_norm:
            normed_for_mlp = normed_for_mixer
        else:
            normed_for_mlp = self.norm2(residual.to(dtype=self.norm2.weight.dtype))

        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        extra = {} if mixer_kwargs is None else mixer_kwargs
        mixer_out = self.mixer(normed_for_mixer, **extra)
        mlp_out = self.mlp(normed_for_mlp)
        return mixer_out, mlp_out, residual