# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
18
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
19
from flax.linen.attention import combine_masks
20
21
22
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
23
from jax.ad_checkpoint import checkpoint_name
24
25
26

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
27
from ..attention import AttnBiasType, AttnMaskType, QKVLayout
28
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
29
from ..attention import fused_attn
30
from ..softmax import SoftmaxType
31
32
33
34
35
36
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
37
38
39
40
41

# Type aliases shared by the modules in this file.
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Precision argument: None, a name or `lax.Precision`, or a pair of either.
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical axis rules: (logical_axis_name, mesh_axis_name or None) pairs.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, as a tuple. The given rules come first in
        their original order, followed by every TE rule whose axis name is not already
        present in the given rules.

    Raises
    ------
    AssertionError
        If a given rule is not an ``(axis_name: str, mesh_axis_name: str | None)`` pair,
        or if a given rule maps an axis that TE also maps, but to a different mesh axis
        (or maps it more than once).
    """
    # Group the caller's mesh axes by logical axis name so conflicts can be detected.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item[0], item[1]
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for item in get_sharding_map_logic_axis_to_mesh_axis().items():
        key, val = item[0], item[1]
        if key in rules_map:
            # A caller-provided rule for a TE axis must agree exactly with TE's mapping.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


113
114
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention built from native JAX ops (einsum + TE ``Softmax``).

    Supports multi-head attention and grouped-query attention (query heads an integer
    multiple of key/value heads), an optional attention bias, sliding-window masking and
    dropout on the attention probabilities.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    float32_logits: bool = False
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Compute attention over ``query`` / ``key`` / ``value``.

        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Rank-4 tensors laid out as ``[seq, batch, heads, head_dim]`` when
            ``transpose_batch_sequence`` is True, else ``[batch, seq, heads, head_dim]``.
            ``key`` and ``value`` must have the same number of heads. The query head
            count must be a multiple of that number.
        mask: jax.numpy.ndarray, default = None
            Mask in which 1 marks positions to mask out. Ignored for ``NO_MASK``, and for
            ``CAUSAL_MASK`` when no sliding window is set.
        bias: jax.numpy.ndarray, default = None
            Attention bias, applied according to ``attn_bias_type``.
        dropout_rng: PRNGKey, default = None
            Key used to sample the attention dropout mask.
        deterministic: bool, default = False
            Disable attention dropout if set to True.

        Returns
        -------
        jax.numpy.ndarray
            Attention output in the same sequence/batch ordering as ``query``.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if self.float32_logits:
            # Upcast q/k so the QK^T logits are computed in float32. Casting to
            # `self.dtype` here would make this flag a no-op for reduced-precision dtypes.
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (kv_head, group) so each group shares one kv head.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Fold the group axis into the head axis so the rest of the code sees [b, h, q, k].
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                # NOTE(review): `bias` is also passed to Softmax below; confirm Softmax does
                # not add it a second time on this path.
                attn_weights += bias

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            # Positions already masked stay masked; unmasked ones take the window mask.
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
                if mask is not None:
                    return SoftmaxType.SCALED_MASKED, mask
                return SoftmaxType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
            attn_weights, mask, bias
        ).astype(self.dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Inverted dropout: rescale kept entries by 1 / keep_prob.
            multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
            attn_weights = attn_weights * multiplier

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
258
259


260
261
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention backed by the ``fused_attn`` primitive.

    This module always hands batch-first tensors to ``fused_attn``. When
    ``transpose_batch_sequence`` is True, the inputs are transposed to batch-first before
    the call and the output is transposed back afterwards.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Pack the inputs according to ``qkv_layout`` and run ``fused_attn`` once.

        Raises
        ------
        ValueError
            If ``qkv_layout`` is not BS3HD, BSHD_BS2HD or BSHD_BSHD_BSHD.
        """
        seed = None
        if dropout_rng is not None:
            # Split the dropout key into num_of_devices() seeds (presumably one per device).
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # TODO(rewang): integrate THD format
        if self.qkv_layout == QKVLayout.BS3HD:
            # qkvpacked format: `query` is the packed tensor, shape = [..., 3, h, d];
            # `key` and `value` are ignored.
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (qkv_packed,)
        elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
            # kvpacked format: `query` has shape [..., h, d] and `key` is the packed kv
            # tensor with shape [..., 2, h, d]; `value` is ignored.
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (query, kv_packed)
        elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        # All layouts share the same kernel arguments; only the packed inputs differ.
        x = fused_attn(
            qkv,
            bias,
            mask,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
        )

        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        return x


375
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
    context_parallel_causal_load_balanced: bool, default = False
        Indicates the sequences are ordered for causal mask load balancing when running
        context parallelism.
    context_parallel_axis: str, default = ""
        The name of the context parallel axis.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type of the module parameters.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: AttnBiasType = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # Per the documented contract, num_gqa_groups=None means one kv head per query head.
        # Resolve it so the kernel availability query receives a concrete head count.
        num_gqa_groups = (
            self.num_attention_heads if self.num_gqa_groups is None else self.num_gqa_groups
        )

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout == QKVLayout.BS3HD:
            # Packed qkv: keys/values share the query's sequence length.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            self.head_dim,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(self.head_dim)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if not use_fused_attn:
            # unfused attention only supports split query, key, value
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

        return x
645
646


647
648
649
650
651
652
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding

    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor. Coordinates of the last (``Hidden``) axis are rotated in pairs,
        so its length is expected to be even.
    windows: Tuple[int, int]
        ``(min_window, max_window)``. One time-scale is generated per coordinate pair,
        spaced geometrically from ``min_window`` towards ``max_window``.
    transpose_batch_sequence: bool
        If True the sequence axis is axis 0, otherwise it is axis 1.
    group_method: str, default = 'consecutive'
        How coordinates are paired. 'consecutive' pairs index ``i`` with ``i + 1``;
        'alternate' pairs index ``i`` with ``i + Hidden / 2``. Case, surrounding
        whitespace, '-' and '_' are ignored.

    Returns
    -------
    jax.numpy.ndarray
        The rotated tensor, with the same shape and dtype as ``x``.

    Raises
    ------
    AssertionError
        If ``group_method`` is neither 'consecutive' nor 'alternate'.
    """

    def canonicalize_group_method(gm):
        # Accept spellings such as "Alternate", " consecutive " or "alter_nate".
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid relative positional embedding group method. "
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."
        )
        return canonicalized_gm

    # Validate up front so an invalid argument fails before any array is built.
    group_method = canonicalize_group_method(group_method)

    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2

    min_window = windows[0]
    max_window = windows[1]

    # fraction covers [0, 1), so the time-scales go from min_window up towards max_window.
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Leading singleton axes so time_scales broadcasts against x along the last axis.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Positions 0..Seqlen-1 laid out along seq_dim, singleton everywhere else.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        # Rotate the pair (x[i], x[i + Hidden/2]) for every i in the first half.
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1)
        return output

    def consecutive_impl():
        # Rotate the pair (x[2k], x[2k + 1]); both members share one time-scale.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        # x_shifted_left[i] == x[i + 1], x_shifted_right[i] == x[i - 1].
        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        # Parity 1 (odd index) selects x[i - 1], parity 0 (even index) selects x[i + 1],
        # i.e. each element is paired with its partner in the (2k, 2k + 1) pair.
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 on even indices, +1 on odd indices:
        # (a, b) -> (a * cos - b * sin, b * cos + a * sin).
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()
724
725


726
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    Plain container of three boolean flags that record where low rank adaptation
    (LoRA) is requested.

    Parameters
    ----------
    qkv_proj: bool, default = False
        Flag for the 'qkv_proj' scope (query/key/value projections).
    output_proj: bool, default = False
        Flag for the 'output_proj' scope.
    mlp: bool, default = False
        Flag for the 'mlp' scope.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Returning NotImplemented for foreign types lets e.g. `scope == None` evaluate
        # to False instead of raising AttributeError on `other.qkv_proj`.
        if not isinstance(other, LoRAScope):
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == (
            other.qkv_proj,
            other.output_proj,
            other.mlp,
        )

    # Instances are mutable, so they stay unhashable (defining __eq__ already implies
    # this; spelled out here for clarity).
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj}, "
            f"output_proj={self.output_proj}, mlp={self.mlp})"
        )
740
741
742
743


def _canonicalize_lora_scope(scope):
    """
    Convert a low rank adaptation scope string into a :class:`LoRAScope`.

    Parameters
    ----------
    scope: Optional[str]
        One of 'none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp' (case-insensitive). ``None`` is treated
        as 'none'.

    Returns
    -------
    LoRAScope
        'none' enables nothing, 'all' enables qkv_proj, output_proj and mlp, a single
        name enables only that part, and 'exclude_<name>' enables the other two.

    Raises
    ------
    AssertionError
        If ``scope`` is not one of the accepted values.
    """

    SCOPE_NONE = "none"
    SCOPE_ALL = "all"
    SCOPE_QKV_PROJ = "qkv_proj"
    SCOPE_OUTPUT_PROJ = "output_proj"
    SCOPE_MLP = "mlp"
    SCOPE_EX_QKV_PROJ = "exclude_qkv_proj"
    SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj"
    SCOPE_EX_MLP = "exclude_mlp"

    # scope -> (qkv_proj, output_proj, mlp)
    scope_flags = {
        SCOPE_NONE: (False, False, False),
        SCOPE_ALL: (True, True, True),
        SCOPE_QKV_PROJ: (True, False, False),
        SCOPE_OUTPUT_PROJ: (False, True, False),
        SCOPE_MLP: (False, False, True),
        SCOPE_EX_QKV_PROJ: (False, True, True),
        SCOPE_EX_OUTPUT_PROJ: (True, False, True),
        SCOPE_EX_MLP: (True, True, False),
    }

    scope = SCOPE_NONE if scope is None else scope
    scope = scope.lower()

    assert scope in scope_flags, (
        f"Invalid low rank adaptation scope {scope!r}. "
        f"Expect to be one of {list(scope_flags)}."
    )

    qkv_proj, output_proj, mlp = scope_flags[scope]
    return LoRAScope(qkv_proj=qkv_proj, output_proj=output_proj, mlp=mlp)


782
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
783
784
785
786
787
788
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
789
    head_dim: int
790
        The hidden dimension of each attention head.
791
792
793
794
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
zlsh80826's avatar
zlsh80826 committed
795
796
797
798
799
800
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
801
802
803
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

823
824
825
826
827
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
828
    dropout_rng_name: str, default = 'dropout'
829
830
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
831
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
832
        Indicate the type of layer normalization.
833
    layernorm_epsilon: float, default = 1e-6
834
        A value added to the denominator of layer normalization for numerical stability.
835
    zero_centered_gamma: bool, default = False
836
837
838
839
840
841
842
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
843
844
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
845
        Used for initializing the QKV and output projection weights.
846
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
847
    use_bias: bool, default = False
848
        Indicate whether or not to enable bias shifting for QKV and output projections.
849
        If set to False, the layer will not learn additive biases.
850
    bias_init: Initializer, default = flax.linen.initializers.zeros
851
852
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
853
854
855
856
857
858
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
859
860
861
862
863
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
864
865
866
867
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate how the coordinates are paired for rotation. It should be one of
        ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
868
869
870
871
872
873
874
875
876
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not 'none'.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
877
878
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
879
880
881
882
883
884
885
886
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`.
887
888
889

    Optimization parameters
    -----------------------
890
891
892
893
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type of the module parameters.
894
    fuse_qkv_params: bool, default = True
895
        If set to True, this module exposes a single fused
896
897
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
898
    transpose_batch_sequence: bool, default = True
899
        Indicate whether the batch and sequence length axes of the input tensors are
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
903
        Indicate whether to scale attention logits.
904
        If set to True, :math:`\frac{QK^T}{\sqrt{head\_dim}}`,
        else :math:`QK^T`
906
907
908
909
910
911
912
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
913
914
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
915
916
917
    """

    head_dim: int
918
919
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
920
921
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
922
    input_layernorm: bool = True
923
924
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
925
    return_layernorm_output: bool = False
926
    zero_centered_gamma: bool = False
927
928
929
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
930
    attn_mask_type: str = "causal"
931
    attn_bias_type: Optional[str] = None
932
933
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
934
935
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
936
937
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
938
    dtype: DType = jnp.float32
939
    weight_dtype: DType = jnp.float32
940
    fuse_qkv_params: bool = True
941
    transpose_batch_sequence: bool = True
942
    enable_sequence_parallel: bool = False
943
944
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
945
    float32_logits: bool = False
946
    window_size: Optional[Tuple[int, int]] = None
947
948
949
950
951
952
953

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
954
955

    def __post_init__(self):
        """
        Resolve deprecated arguments and fill in derived defaults, then call
        ``super().__post_init__()``.

        * ``num_heads`` / ``dropout_rate``: when given, they overwrite
          ``num_attention_heads`` / ``attention_dropout`` and emit a ``DeprecationWarning``.
        * ``apply_residual_connection_post_layernorm`` / ``fuse_qkv``: when given, only a
          ``DeprecationWarning`` is emitted; the values are not copied to any other field.
        * ``output_layernorm`` must be ``None``, otherwise an ``AssertionError`` is raised.
        * ``kernel_init`` defaults to ``variance_scaling(1.0, 'fan_in', 'normal')`` created
          with ``weight_dtype``; ``num_gqa_groups`` defaults to ``num_attention_heads``
          (evaluated after the ``num_heads`` alias has been applied).
        """
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )

        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )

        # NOTE(review): unlike num_heads / dropout_rate above, this value is not forwarded
        # to return_layernorm_output -- confirm that warning only is intended.
        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )

        # NOTE(review): this value is not forwarded to fuse_qkv_params -- confirm intended.
        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )

        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.weight_dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
997
998
999
1000
1001
1002
1003
1004
1005
1006
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1007
1008
1009
1010
1011
1012
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1013
        inputs_q: jax.numpy.ndarray
1014
            Input tensor for query projection.
1015
        inputs_kv: jax.numpy.ndarray
1016
            Input tensor for key/value projection.
1017
1018
1019
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
1020
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
1021
1022
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1023
        *
1024
        decode: bool, default = False
1025
            Indicate whether to prepare and use an autoregressive cache.
1026
        deterministic: bool, default = False
1027
1028
1029
1030
            Disable dropout layers if set to True.

        Returns
        -------
1031
        outputs: jax.numpy.ndarray
1032
1033
            Output tensors.
        """
1034
1035

        def query_init(*args):
1036
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

1079
1080
1081
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1082

1083
1084
1085
1086
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1087
1088
1089
1090
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1091
1092
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1093
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1094
            if is_qkvpack:
1095
                qkv_proj, ln_out = LayerNormDenseGeneral(
1096
                    enable_layernorm=self.input_layernorm,
1097
                    layernorm_type=self.layernorm_type,
1098
                    zero_centered_gamma=self.zero_centered_gamma,
1099
1100
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1101
                    features=(3, self.num_attention_heads * self.head_dim),
1102
                    transpose_batch_sequence=self.transpose_batch_sequence,
1103
                    return_layernorm_output=self.return_layernorm_output,
1104
1105
1106
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1107
1108
1109
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1110
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1111
1112
1113
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1114
1115
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1116
1117
                    name="qkv",
                    dtype=self.dtype,
1118
                    weight_dtype=self.weight_dtype,
1119
1120
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1121
                qkv_layout = QKVLayout.BS3HD
1122
1123
            else:
                query, ln_out = LayerNormDenseGeneral(
1124
                    enable_layernorm=self.input_layernorm,
1125
                    layernorm_type=self.layernorm_type,
1126
                    zero_centered_gamma=self.zero_centered_gamma,
1127
1128
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1129
                    features=self.num_attention_heads * self.head_dim,
1130
                    transpose_batch_sequence=self.transpose_batch_sequence,
1131
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1132
1133
1134
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1135
1136
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1137
                    bias_axes=(W_TP_AXES,),
1138
1139
1140
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1141
                    dtype=self.dtype,
1142
                    weight_dtype=self.weight_dtype,
1143
                    kernel_init=query_init,
1144
1145
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1146
1147
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1148
1149
1150
1151
1152

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name="kv",
                    dtype=self.dtype,
1167
                    weight_dtype=self.weight_dtype,
1168
1169
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1170
                qkv_layout = QKVLayout.BSHD_BS2HD
1171
1172
1173
1174
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1175
                features=self.num_gqa_groups * self.head_dim,
1176
                transpose_batch_sequence=self.transpose_batch_sequence,
1177
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1178
1179
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1180
                bias_axes=(W_TP_AXES,),
1181
1182
1183
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1184
                dtype=self.dtype,
1185
                weight_dtype=self.weight_dtype,
1186
            )
1187
            query, ln_out = LayerNormDenseGeneral(
1188
                enable_layernorm=self.input_layernorm,
1189
                layernorm_type=self.layernorm_type,
1190
                zero_centered_gamma=self.zero_centered_gamma,
1191
1192
                epsilon=self.layernorm_epsilon,
                axis=-1,
1193
                features=self.num_attention_heads * self.head_dim,
1194
1195
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
1196
1197
1198
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1199
1200
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1201
                bias_axes=(W_TP_AXES,),
1202
1203
1204
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1205
                dtype=self.dtype,
1206
                weight_dtype=self.weight_dtype,
1207
                kernel_init=query_init,
1208
1209
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1210
1211
                name="query",
            )(inputs_q)
1212

1213
            if is_self_attn:
1214
1215
1216
                assert ln_out is not None
                inputs_kv = ln_out

1217
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1218
            key = key.astype(self.dtype)
1219
1220
1221
1222
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1223
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1224

1225
        if self.enable_rotary_pos_emb:
1226
1227
1228
1229
1230
1231
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1232

1233
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1234
1235
1236
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1249
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1250

1251
1252
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1253
1254
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1255
1256

        if decode:
1257
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1258
1259
1260
1261
1262
1263
1264
1265
1266
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1267
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1268
                if self.transpose_batch_sequence:
1269
1270
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1271
1272
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1273
1274
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1275
                    one_hot_indices_shape = (1, length, 1, 1)
1276
1277
1278
1279

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1280
1281
1282
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1283

1284
                cur_index = cache_index.value.astype(jnp.int32)
1285
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1286
1287
1288
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1289
1290
1291
1292
1293
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1294
1295
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1296
1297

                if bias is not None:
1298
1299
1300
1301
1302
1303
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1304

1305
1306
1307
1308
1309
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1310
1311
1312
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1324
        else:
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1335
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1336
1337
1338
1339
1340
1341
1342
1343
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
1344
            weight_dtype=self.weight_dtype,
1345
1346
1347
1348
1349
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1350
            window_size=self.window_size,
1351
        )(*dpa_args, mask, bias, deterministic=deterministic)
1352
1353
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1354
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1355
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1356

1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            transpose_batch_sequence=self.transpose_batch_sequence,
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
1370
            weight_dtype=self.weight_dtype,
1371
1372
1373
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1374
1375

        return out, ln_out
1376
1377


1378
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Learned T5-style relative position biases that are added to the attention logits.

    Every (query, key) offset is sorted into one of ``num_buckets`` buckets. Small
    distances each get their own bucket. Larger distances share log-spaced buckets,
    and distances of ``max_distance`` or more are clamped into the last bucket.
    When bidirectional, the buckets are split evenly between keys before and after
    the query. Each bucket owns one learned scalar per attention head.

    Parameters
    ----------
    num_buckets: int
        Total number of buckets that key/query offsets are sorted into.
    max_distance: int
        Distance from which all offsets share the final bucket.
    num_attention_heads: int
        Number of attention heads; one bias value is learned per head and bucket.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Initializer for the relative embedding table.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        Logical axis names used to shard the embedding table over a device mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        Data type used for computation.
    weight_dtype: jax.numpy.dtype, default  = jax.numpy.float32
        Data type in which the embedding table is stored.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Build the relative position bias for a pair of query/key lengths.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            If True, keys located after the query get their own half of the buckets.
            If False, every such positive memory-query offset is treated as distance 0.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Signed (query - key) offset for every pair, shape (q_seqlen, k_seqlen).
        query_pos = np.arange(q_seqlen, dtype=jnp.int32)
        key_pos = np.arange(k_seqlen, dtype=jnp.int32)
        query_minus_key = query_pos[:, None] - key_pos[None, :]

        if bidirectional:
            buckets_per_direction = self.num_buckets // 2
            # Keys after the query (negative offset) are routed to the upper half.
            direction_offset = (query_minus_key < 0).astype(np.int32) * buckets_per_direction
            distance = np.abs(query_minus_key)
        else:
            buckets_per_direction = self.num_buckets
            direction_offset = 0
            # Unidirectional: keys after the query collapse onto distance 0.
            distance = np.maximum(query_minus_key, 0)

        # Distances below `exact_limit` map to their own bucket. Larger ones are
        # log-spaced up to `max_distance` and clamped into the direction's last bucket.
        exact_limit = buckets_per_direction // 2
        log_position = np.log(
            distance.astype(np.float32) / exact_limit + np.finfo(np.float32).eps
        ) / np.log(self.max_distance / exact_limit)
        log_bucket = exact_limit + (
            log_position * (buckets_per_direction - exact_limit)
        ).astype(np.int32)
        log_bucket = np.minimum(log_bucket, buckets_per_direction - 1)
        rp_bucket = direction_offset + np.where(distance < exact_limit, distance, log_bucket)

        embedding_table = nn_partitioning.param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_attention_heads, self.num_buckets),
            self.weight_dtype,
            axes=self.embedding_axes,
        )
        embedding_table = jnp.asarray(embedding_table, self.dtype)

        # One-hot encode the bucket ids on a new leading axis:
        # (num_buckets, q_seqlen, k_seqlen).
        bucket_ids = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        bucket_one_hot = jnp.array(rp_bucket[None, ...] == bucket_ids, dtype=self.dtype)

        # Contract the bucket axis: (heads, buckets) x (buckets, q, k) -> (heads, q, k).
        per_head_bias = lax.dot_general(
            embedding_table, bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        return jnp.expand_dims(per_head_bias, 0)


class TransformerLayerType(Enum):
    """
    Selects which variant of ``TransformerLayer`` is built.

    Values
    ------
    ENCODER:
        Encoder layer: self-attention followed by the MLP block.
    DECODER:
        Decoder layer: self-attention, then cross-attention over the encoder
        output, then the MLP block.
    """

    # Plain string values, so e.g. TransformerLayerType("decoder") resolves to DECODER.
    ENCODER = "encoder"
    DECODER = "decoder"


1494
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
1495
1496
1497
1498
1499
1500
1501
1502
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
1503
        The hidden size of each input sample.
1504
    mlp_hidden_size: int, default = 2048
1505
        Intermediate size to which input samples are projected.
1506
    num_attention_heads: int, default = 8
1507
        Number of attention heads in the transformer layer.
1508
    num_gqa_groups: int, default = `None`
zlsh80826's avatar
zlsh80826 committed
1509
1510
1511
1512
1513
1514
1515
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the querys.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1516
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1517
        Indicate the type of layer normalization.
1518
    layernorm_epsilon: float, default = 1e-6
1519
        A value added to the denominator of layer normalization for numerical stability.
1520
    zero_centered_gamma: bool, default = False
1521
1522
1523
1524
1525
1526
1527
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1528
    hidden_dropout: float, default = 0.1
1529
        Dropout probability for the dropout op after FC2 layer.
1530
    hidden_dropout_dims: Sequence[int], default = ()
1531
        Dimensions that will share the same dropout mask for hidden
1532
    attention_dropout: float, default = 0.1
1533
        Dropout probability for the dropout op during multi-head attention.
1534
1535
1536
1537
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
1538
    dropout_rng_name: str, default = 'dropout'
1539
1540
        The key in given RNGs via flax.linen.Module.apply that for
        generating Dropout masks in the Multi-Head Attention.
1541
1542
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
1543
1544
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1545
1546
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
1547
1548
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1549
    mlp_activations: Sequence[str], default = ('relu', )
1550
        The sequence of activation functions to apply after the first linear transformation.
1551
1552
        Each activation has its own transformation layer.
    use_bias: bool, default = False
1553
1554
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
1555
    bias_init: Initializer, default = flax.linen.initializers.zeros
1556
1557
1558
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1559
    apply_residual_connection_post_layernorm: bool, default = False
1560
        If set to True, residual connections are taken from the output
1561
1562
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
1563
        If set to True, layer normalization is applied on the output side,
1564
1565
1566
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
1567
1568
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
1569
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
1570
        If set to TransformerLayerType.DECODER, an additional cross-attention block
1571
1572
        is added after self-attention.this can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
1573
    self_attn_mask_type: str, default = 'causal'
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

1593
1594
1595
1596
1597
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
1598
    enable_relative_embedding: bool, default = True
1599
        Whether to enable relative embedding as shifting of attention logits.
1600
    relative_embedding: flax.linen.Module, default = None
1601
        The module for relative embedding execution, only used when
1602
1603
1604
1605
1606
1607
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
1608
1609
1610
1611
1612
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1613
    rotary_pos_emb_group_method: str, default = 'consecutive'
1614
1615
1616
1617
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
1618
1619
1620
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
1621
        'exclude_output_proj', 'exclude_mlp']
1622
1623
1624
1625
1626
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
1627
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
1628
1629
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1630
1631
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1632
1633
1634

    Optimization parameters
    -----------------------
1635
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1636
1637
1638
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type of the module parameters.
1639
    drop_path: float, default = 0.0
1640
        When > 0.0, applies stochastic depth per sample in the main
1641
1642
        path of the residual block.
    fuse_qkv_params: bool, default = True
1643
        If set to True, `TransformerLayer` module exposes a single fused
1644
1645
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1646
    transpose_batch_sequence: bool, default = False
1647
        Indicate whether the input tensors were switched axis of batch
1648
1649
1650
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
1651
        Indicate whether to scale attention logits.
1652
1653
1654
        if set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
1655
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
1656
1657
1658
1659
1660
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
1661
    num_gqa_groups: Optional[int] = None
1662
    layernorm_type: str = "layernorm"
1663
    layernorm_epsilon: float = 1e-6
1664
    zero_centered_gamma: bool = False
1665
1666
1667
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
1668
1669
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
1670
    dropout_rng_name: str = "dropout"
1671
1672
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
1673
    mlp_activations: Sequence[str] = ("relu",)
1674
1675
1676
1677
1678
1679
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
1680
    self_attn_mask_type: str = "causal"
1681
    self_attn_bias_type: Optional[str] = None
1682
1683
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
1684
1685
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1686
1687
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1688
1689
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1690
    dtype: DType = jnp.float32
1691
    weight_dtype: DType = jnp.float32
1692
1693
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
1694
    transpose_batch_sequence: bool = False
1695
    enable_sequence_parallel: bool = False
1696
1697
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1698
    window_size: Optional[Tuple[int, int]] = None
1699
1700
1701

    def __post_init__(self):
        if self.mha_kernel_init is None:
1702
            self.mha_kernel_init = nn.initializers.variance_scaling(
1703
                1.0, "fan_in", "normal", dtype=self.weight_dtype
1704
            )
1705
        if self.mlp_kernel_init is None:
1706
            self.mlp_kernel_init = nn.initializers.variance_scaling(
1707
                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
1708
            )
zlsh80826's avatar
zlsh80826 committed
1709
1710
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
1711
1712
1713
        super().__post_init__()

    @nn.compact
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: bool = None,
    ):
1724
1725
1726
1727
1728
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
1729
        inputs: jax.numpy.ndarray
1730
            Input tensor.
1731
        encoded: jax.numpy.ndarray, default = None
1732
1733
1734
1735
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
1736
1737
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
1738
        encoder_decoder_mask: jax.numpy.ndarray, default = None
1739
1740
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
1741
            :attr:`True` means mask out the corresponding values.
1742
        deterministic: bool, default = False
1743
            Disable dropout layers if set to True.
1744
        decode: bool, default = False
1745
1746
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
1747
        max_decode_length: bool, default = None
1748
1749
1750
1751
1752
1753
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
1754
        outputs: jax.numpy.ndarray
1755
            Output tensors.
1756
        """
1757
1758
1759

        inputs = inputs.astype(self.dtype)

1760
1761
1762
        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
1763

1764
1765
1766
1767
        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )
1768

1769
1770
1771
        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."
1772
1773
1774
1775
1776
1777

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

1778
1779
1780
        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            axes = [None, None]

1781
1782
1783
            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )
1784
1785
1786
1787
1788

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

1789
1790
1791
        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
1792
1793
1794
1795
1796
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
1797
                    weight_dtype=self.weight_dtype,
1798
1799
1800
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Make name be the exactly same as T5X, since names would affect
        # RNGKey during init and apply. Myabe no need in the feature.
        if self.layer_type == TransformerLayerType.ENCODER:
1818
            mha_name = "attention"
1819
        else:
1820
            mha_name = "self_attention"
1821

1822
        inputs = with_sharding_constraint_by_logical_axes(
1823
1824
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1825

1826
        # [batch, length, emb_dim] -> [batch, length, emb_dim]
1827
1828
1829
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
1830
            dtype=self.dtype,
1831
            weight_dtype=self.weight_dtype,
1832
            head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1833
            num_gqa_groups=self.num_gqa_groups,
1834
            transpose_batch_sequence=self.transpose_batch_sequence,
1835
            enable_sequence_parallel=self.enable_sequence_parallel,
1836
            attention_dropout=self.attention_dropout,
1837
1838
1839
1840
1841
1842
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
1843
            zero_centered_gamma=self.zero_centered_gamma,
1844
1845
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
1846
            attn_mask_type=self.self_attn_mask_type,
1847
            attn_bias_type=self.self_attn_bias_type,
1848
1849
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1850
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1851
1852
1853
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1854
            fuse_qkv_params=self.fuse_qkv_params,
1855
1856
1857
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
1858
            name=mha_name,
1859
            window_size=self.window_size,
1860
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
1861
1862
1863
1864
1865

        def hidden_dropout(x, deterministic):
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1866
                assert -x_shape_len <= dims < x_shape_len
1867

1868
1869
1870
1871
1872
            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1873

1874
        x = with_sharding_constraint_by_logical_axes(
1875
1876
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1877
        residual = with_sharding_constraint_by_logical_axes(
1878
1879
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1880

1881
1882
1883
        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
1884
1885
1886
1887
1888
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1889
1890
1891
1892
1893

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

1894
1895
1896
1897
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
1898
1899
1900
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."
1901

1902
            x = with_sharding_constraint_by_logical_axes(
1903
1904
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1905

1906
1907
1908
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
1909
                dtype=self.dtype,
1910
                weight_dtype=self.weight_dtype,
1911
                head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1912
                num_gqa_groups=self.num_gqa_groups,
1913
                transpose_batch_sequence=self.transpose_batch_sequence,
1914
                enable_sequence_parallel=self.enable_sequence_parallel,
1915
                attention_dropout=self.attention_dropout,
1916
1917
1918
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
1919
                zero_centered_gamma=self.zero_centered_gamma,
1920
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
1921
1922
1923
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
1924
1925
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1926
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1927
1928
1929
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1930
1931
1932
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
1933
                fuse_qkv_params=self.fuse_qkv_params,
1934
1935
1936
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1937
                name="encoder_decoder_attention",
1938
                window_size=self.window_size,
1939
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)
1940
1941

            y = with_sharding_constraint_by_logical_axes(
1942
1943
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1944
            residual = with_sharding_constraint_by_logical_axes(
1945
1946
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1947

1948
            y = hidden_dropout(y, deterministic)
1949
1950
1951
1952
1953

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

1954
1955
            mlp_input = y + residual

1956
        mlp_input = with_sharding_constraint_by_logical_axes(
1957
1958
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1959

1960
1961
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1962
1963
1964
1965
        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
1966
            zero_centered_gamma=self.zero_centered_gamma,
1967
1968
1969
1970
1971
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
1972
1973
1974
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
1975
            dtype=self.dtype,
1976
            weight_dtype=self.weight_dtype,
1977
1978
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
1979
            kernel_init=self.mlp_kernel_init,
1980
1981
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
1982
1983
            use_bias=self.use_bias,
            bias_init=self.bias_init,
1984
1985
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
1986
1987
1988
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1989
1990
1991
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
1992
            name="mlp",
1993
1994
1995
1996
1997
1998
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

1999
        z = with_sharding_constraint_by_logical_axes(
2000
2001
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2002
        residual = with_sharding_constraint_by_logical_axes(
2003
2004
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2005

2006
2007
2008
        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
2009
2010
2011
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
2012
2013
2014
        z = z + residual

        if self.output_layernorm:
2015
            z = with_sharding_constraint_by_logical_axes(
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                transpose_batch_sequence=self.transpose_batch_sequence,
                dtype=self.dtype,
2026
                weight_dtype=self.weight_dtype,
2027
2028
                name="output_layernorm",
            )(z)
2029
2030

        return z