"vscode:/vscode.git/clone" did not exist on "eb84e5d5ec2e5477bcffd7196c52dc7a98bfd5eb"
attention.py 18.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention
from .embeddings import CombinedTimestepLabelEmbeddings


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Kashif Rasul's avatar
Kashif Rasul committed
38
39
40
    Uses three q, k, v linear layers to compute attention.

    Parameters:
Will Berman's avatar
Will Berman committed
41
42
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
Kashif Rasul's avatar
Kashif Rasul committed
43
            The number of channels in each head. If None, then `num_heads` = 1.
Will Berman's avatar
Will Berman committed
44
45
46
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
Patrick von Platen's avatar
Patrick von Platen committed
47
48
    """

Will Berman's avatar
Will Berman committed
49
50
    # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        self._use_memory_efficient_attention_xformers = False
        self._attention_op = None

    def reshape_heads_to_batch_dim(self, tensor):
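        # splits the last dimension across heads and folds the heads into the batch dimension:
        # (batch, seq_len, dim) -> (batch * num_heads, seq_len, dim // num_heads)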
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
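        # inverse of `reshape_heads_to_batch_dim`:
        # (batch * num_heads, seq_len, dim // num_heads) -> (batch, seq_len, dim)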
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
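        # Toggles xformers memory-efficient attention for this block. A usage sketch (illustrative,
        # not from the original module): `block.set_use_memory_efficient_attention_xformers(True)`,
        # which requires xformers to be installed and a CUDA device to be available.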
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers",
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self._attention_op = attention_op

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

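        # scale dot products by 1 / sqrt(d_head), where d_head = channels // num_heads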
        scale = 1 / math.sqrt(self.channels / self.num_heads)
        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        if self._use_memory_efficient_attention_xformers:
            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(
                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
            )
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
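            # standard attention: baddbmm computes alpha * (query_proj @ key_proj^T); with beta=0 the
            # values of the uninitialized `torch.empty` tensor are ignored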
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)
        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

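
# Illustrative usage sketch (not part of the original module): AttentionBlock applies self-attention
# over the height * width positions of a 4-D feature map and preserves its shape. The sizes below are
# arbitrary assumptions chosen for demonstration.
def _example_attention_block():
    block = AttentionBlock(channels=64, num_head_channels=32)  # 2 heads of 32 channels each
    sample = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    output = block(sample)
    assert output.shape == sample.shape
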
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # 1. Self-Attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is None
        else:
            self.attn2 = None

        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self-attention, where there will only be one attention block.
            # I.e. the number of modulation chunks returned by AdaLayerNormZero would not make sense if they were
            # returned during the second cross-attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
        else:
            self.norm2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        class_labels=None,
    ):
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        # 1. Self-Attention
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
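            # adaLN-Zero: scale the attention output by the learned per-sample gate before the residual add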
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            # 2. Cross-Attention
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states
        return hidden_states
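

# Illustrative usage sketch (not part of the original module): a BasicTransformerBlock over a token
# sequence, with cross-attention to a separate encoder sequence. All sizes below are arbitrary
# assumptions for demonstration; the block preserves the shape of `hidden_states`.
def _example_basic_transformer_block():
    block = BasicTransformerBlock(
        dim=64,
        num_attention_heads=8,
        attention_head_dim=8,
        cross_attention_dim=32,
    )
    hidden_states = torch.randn(2, 77, 64)  # (batch, seq_len, dim)
    encoder_hidden_states = torch.randn(2, 16, 32)  # (batch, encoder_seq_len, cross_attention_dim)
    output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
    assert output.shape == hidden_states.shape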


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
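

# Illustrative usage sketch (not part of the original module): FeedForward expands the last dimension
# to `mult * dim` through the chosen activation (GEGLU by default) and projects back to `dim_out`
# (here `dim`), so the input shape is preserved. Sizes are arbitrary assumptions.
def _example_feed_forward():
    ff = FeedForward(dim=64, mult=4, activation_fn="geglu")
    hidden_states = torch.randn(2, 77, 64)
    assert ff(hidden_states).shape == hidden_states.shape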

class GELU(nn.Module):
    r"""
    GELU activation function, with optional tanh approximation via `approximate="tanh"`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
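        # project to 2 * dim_out so the output can be split into a value half and a gating half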
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
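        # torch.chunk splits along dim 0 by default, so `emb` is assumed to be a 1-D (non-batched) embedding here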
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
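

# Illustrative usage sketch (not part of the original module): AdaLayerNormZero normalizes a token
# sequence and returns it together with the gate/shift/scale chunks that the caller (e.g.
# BasicTransformerBlock above) applies around attention and feed-forward. The batched integer
# timestep/class-label tensors and all sizes below are assumptions for demonstration only.
def _example_ada_layer_norm_zero():
    norm = AdaLayerNormZero(embedding_dim=64, num_embeddings=10)  # 10 class labels
    x = torch.randn(2, 77, 64)  # (batch, seq_len, embedding_dim)
    timestep = torch.randint(0, 1000, (2,))
    class_labels = torch.randint(0, 10, (2,))
    x_out, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, timestep, class_labels, hidden_dtype=x.dtype)
    assert x_out.shape == x.shape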