transformer.py 92.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
18
from flax.linen.attention import combine_masks
19
20
21
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
27
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
28
from ..attention import fused_attn
29
from ..attention import CPStrategy
30
from ..softmax import SoftmaxType
31
32
33
34
35
36
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
37
38
39
40
41

# Type aliases shared by the modules in this file.
PRNGKey = Any  # a JAX PRNG key
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Anything accepted as a `precision` argument by JAX/lax dot-like ops.
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
# Parameter initializer signature: (key, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical axis rules: a sequence of (logical_axis_name, mesh_axis_name or None).
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    """Return the axis indices of ``shape`` with ``batch_dim`` removed.

    These are the axes a drop-path mask is broadcast over, so each batch entry is
    dropped as a whole. ``batch_dim`` may be negative, as with list indexing.
    """
    axes = list(range(len(shape)))
    del axes[batch_dim]
    return axes


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend. Each rule is an
        ``(axis_name, mesh_axis_name)`` pair, where ``mesh_axis_name`` may be None.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, returned as a tuple: the given rules in their
        original order, followed by every TE rule whose axis name does not appear in ``rules``.

    Raises
    ------
    AssertionError
        If a rule is malformed, or if an axis name from ``rules`` is also defined by TE but
        is listed more than once or maps to a different mesh axis than TE's rule.
    """
    # Collect every mesh axis each user-supplied logical axis maps to, so that duplicates
    # and conflicts with TE's own rules can be detected below.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for key, val in get_sharding_map_logic_axis_to_mesh_axis().items():
        if key in rules_map:
            # A user rule for an axis TE also defines must agree exactly with TE's rule.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append((key, val))
    return tuple(extended_rules)


113
114
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention implemented with native JAX operations.

    Inputs are sequence-major (``[s, b, h, d]``) when ``transpose_batch_sequence`` is True,
    otherwise batch-major (``[b, s, h, d]``). Grouped-query attention is used when the
    query has more heads than key/value.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    float32_logits: bool = False
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Compute ``softmax(scale * q k^T [+ bias], mask) v`` with JAX ops.

        Parameters
        ----------
        query: Array
            ``[s, b, h_q, d]`` if ``transpose_batch_sequence`` else ``[b, s, h_q, d]``.
        key, value: Array
            Same batch/sequence order as ``query`` with ``h_kv`` heads; ``h_q`` must be a
            multiple of ``h_kv``.
        mask: Optional[Array]
            Entries equal to 1 are masked out. Ignored for ``NO_MASK`` and, when no sliding
            window is configured, for ``CAUSAL_MASK``.
        bias: Optional[Array]
            Added before scaling for ``PRE_SCALE_BIAS``, after scaling for
            ``POST_SCALE_BIAS``.
        dropout_rng: Optional[PRNGKey]
            Used when dropout is active (``not deterministic`` and ``attention_dropout > 0``).
        deterministic: bool
            Disables attention dropout when True.

        Returns
        -------
        Array
            The attention output, in the same batch/sequence order as ``query``.

        Raises
        ------
        ValueError
            If the mask type requires a mask (e.g. padding) but ``mask`` is None.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        input_dtype = query.dtype

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (h_kv, group_size) so each kv head serves a group.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # (b, h, q, k): Last two axes are always replicated
        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
        )

        # The Softmax module below adds `bias` to its input before applying its own scale.
        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                # The bias is already folded into the logits; passing it to Softmax as well
                # would add it a second time.
                bias = None

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
            attn_weights, mask, bias
        ).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Inverted dropout: rescale kept weights by 1 / keep_prob.
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
264
265


266
267
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention backed by the ``fused_attn`` (cuDNN) primitive.

    Dispatches the inputs to :func:`fused_attn` according to ``qkv_layout``. When
    ``transpose_batch_sequence`` is True, the inputs are transposed from sequence-major to
    batch-major before the call, and the output is transposed back afterwards.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Run fused attention and return its output.

        Parameters
        ----------
        query, key, value: Array
            Interpreted according to ``qkv_layout``:

            * qkvpacked: ``query`` is the packed ``[..., 3, h, d]`` tensor; ``key`` and
              ``value`` are ignored.
            * kvpacked: ``query`` is ``[..., h, d]`` and ``key`` is the packed
              ``[..., 2, h, d]`` tensor; ``value`` is ignored.
            * separate: ``query``, ``key`` and ``value`` are individual tensors.
        sequence_descriptor: Optional[SequenceDescriptor]
            Forwarded unchanged to :func:`fused_attn`.
        bias: Optional[Array]
            Forwarded unchanged to :func:`fused_attn`.
        dropout_rng: Optional[PRNGKey]
            When given, split into ``num_of_devices()`` keys used as the dropout seed.
        deterministic: bool
            Forwarded to :func:`fused_attn` as ``is_training=not deterministic``.

        Returns
        -------
        Array
            The attention output, in the same batch/sequence order as the inputs.

        Raises
        ------
        ValueError
            If ``qkv_layout`` is neither qkvpacked, kvpacked nor separate.
        """
        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # Build the layout-specific input tuple; fused_attn takes it as its first argument.
        if self.qkv_layout.is_qkvpacked():
            # `query` holds the packed qkv tensor [..., 3, h, d]; `key` and `value` are ignored.
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            # `query` is [..., h, d]; `key` holds the packed kv tensor [..., 2, h, d];
            # `value` is ignored.
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            qkv,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
            context_parallel_strategy=self.context_parallel_strategy,
            context_checkpoint_name=self.context_checkpoint_name,
        )

        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        assert x.dtype == query.dtype
        return x


392
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

        .. note:: THD format only supports 'padding' or 'causal_padding' mask type.

        Summary of the mask handling::

            attn_mask_type   mask/sequence_descriptor   SWA      softmax type
            ---------------------------------------------------------------------------
            no_mask          None                       None     SCALED
            causal           None                       None     SCALED_UPPER_TRIANG_MASKED
            causal           None                       Yes      SCALED_MASKED
            padding          Required                   Yes/No   SCALED_MASKED
            padding_causal   Required                   Yes/No   SCALED_MASKED

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
    context_parallel_causal_load_balanced (bool):
        Indicates the sequences are ordered for causal mask load balancing when running context
        parallelism.
    context_parallel_axis (str): The name of the context parallel axis.
    context_parallel_strategy (str or CPStrategy): The strategy of context parallel.
        Either a :class:`CPStrategy` member, or one of the case-insensitive names
        'DEFAULT', 'ALL_GATHER' (also 'ALLGATHER') and 'RING'.
    context_checkpoint_name (str): The name of the context checkpoint in the forward pass of
        fused attention.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: AttnBiasType = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: Union[str, CPStrategy] = "DEFAULT"
    context_checkpoint_name: str = "context"

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
            Ignored for the qkvpacked layouts.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
            Ignored for the qkvpacked and kvpacked layouts.
        sequence_descriptor: SequenceDescriptor or jax.numpy.ndarray, default = None
            Mask / sequence information for the attention. The fused backend forwards it to the
            fused attention kernel; the unfused backend only accepts ``None`` or a mask array,
            where :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: jax.numpy.ndarray, default = None
            Deprecated alias of :attr:`sequence_descriptor`; a warning is issued when it is used.
            It cannot be given together with :attr:`sequence_descriptor`.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.

        Raises
        ------
        ValueError
            If both :attr:`mask` and :attr:`sequence_descriptor` are given, or if
            :attr:`context_parallel_strategy` is not a recognized strategy.
        """
        input_dtype = query.dtype

        # `mask` is the deprecated spelling of `sequence_descriptor`.
        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        # Every qkvpacked layout (bs3hd and t3hd) ignores `key`, so it must not be inspected.
        if qkv_layout.is_qkvpacked():
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        if qkv_layout.is_separate():
            head_dim_qk = query.shape[-1]
            head_dim_v = value.shape[-1]
        else:
            head_dim_qk = self.head_dim
            head_dim_v = self.head_dim

        # As documented, `num_gqa_groups=None` means as many kv heads as query heads.
        num_gqa_groups = (
            self.num_attention_heads if self.num_gqa_groups is None else self.num_gqa_groups
        )

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            head_dim_qk,
            head_dim_v,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(head_dim_qk)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # Accept a CPStrategy member directly, or a case-insensitive strategy name.
        if isinstance(self.context_parallel_strategy, CPStrategy):
            context_parallel_strategy = self.context_parallel_strategy
        else:
            cp_strategy_map = {
                "DEFAULT": CPStrategy.DEFAULT,
                "ALL_GATHER": CPStrategy.ALL_GATHER,
                "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
                "RING": CPStrategy.RING,
            }
            strategy_key = self.context_parallel_strategy.upper()
            if strategy_key not in cp_strategy_map:
                valid_strategies = list(cp_strategy_map.keys())
                raise ValueError(
                    f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
                    f"Valid options are: {valid_strategies} (case insensitive)"
                )
            context_parallel_strategy = cp_strategy_map[strategy_key]

        if not use_fused_attn:
            # unfused attention only supports split query, key, value
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            # The unfused backend takes a plain mask array, not a SequenceDescriptor.
            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x
737
738


739
740
741
742
743
744
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding

    Rotates pairs of coordinates in the last (hidden) dimension of ``x`` by the angle
    ``position / time_scale``. The time scales form a geometric progression that starts
    at ``windows[0]`` and approaches (but does not reach) ``windows[1]``.

    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor. The hidden (last) dimension must be even, since it is split into
        coordinate pairs.
    windows: Tuple[int, int]
        The (min, max) time scales.
    transpose_batch_sequence: bool
        Selects which of the first two axes is the sequence axis (see shapes above).
    group_method: str, default = 'consecutive'
        How coordinates are paired. 'consecutive' pairs index i with i + 1 (i even).
        'alternate' pairs index i with i + hidden / 2. Matching ignores case,
        surrounding whitespace, '-' and '_'.

    Returns
    -------
    jax.numpy.ndarray
        Rotated tensor with the same shape and dtype as ``x``.

    Raises
    ------
    AssertionError
        If ``group_method`` is not one of the supported methods.
    """

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid relative positional embedding group method. "
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."
        )
        return canonicalized_gm

    # Validate first so an invalid method fails before any array work is done.
    group_method = canonicalize_group_method(group_method)

    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2

    min_window = windows[0]
    max_window = windows[1]

    # Geometric progression: min_window * (max_window / min_window) ** (2k / hidden_dim).
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Prepend singleton axes so the scales broadcast against x's leading dimensions.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Position index along the sequence axis; singleton along batch and trailing axes.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        # Pair coordinate i of the first half with coordinate i of the second half.
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
        return output

    def consecutive_impl():
        # Pair neighbouring coordinates (0, 1), (2, 3), ...; each pair shares one angle.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        # Even indices take their right neighbour and odd indices their left neighbour.
        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 on even indices and +1 on odd indices, which yields a 2-D rotation per pair.
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()
816
817


818
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    A mutable set of boolean flags selecting where low rank adaptation (LoRA) is enabled.

    Attributes
    ----------
    qkv_proj: bool
        Flag for the query/key/value projection(s).
    output_proj: bool
        Flag for the attention output projection (consumer not shown in this section).
    mlp: bool
        Flag for the MLP block (consumer not shown in this section).
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Compare any object exposing the three flags (duck typing). For other objects,
        # defer to Python's default instead of raising AttributeError.
        try:
            other_flags = (other.qkv_proj, other.output_proj, other.mlp)
        except AttributeError:
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == other_flags

    # Instances are mutable, so they deliberately stay unhashable. Defining __eq__ alone
    # already implies this; spelled out here for clarity.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj}, "
            f"output_proj={self.output_proj}, mlp={self.mlp})"
        )
832
833
834
835


def _canonicalize_lora_scope(scope):
    """Translate a low rank adaptation scope string into a ``LoRAScope``.

    Parameters
    ----------
    scope: Optional[str]
        One of 'none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj' or 'exclude_mlp'. Matching is case-insensitive and ignores
        surrounding whitespace. ``None`` is treated as 'none'.

    Returns
    -------
    LoRAScope
        A new instance whose ``qkv_proj``/``output_proj``/``mlp`` flags are set according
        to ``scope``.

    Raises
    ------
    AssertionError
        If ``scope`` is not one of the options above.
    """
    # (qkv_proj, output_proj, mlp) flags enabled by each accepted scope.
    scope_to_flags = {
        "none": (False, False, False),
        "all": (True, True, True),
        "qkv_proj": (True, False, False),
        "output_proj": (False, True, False),
        "mlp": (False, False, True),
        "exclude_qkv_proj": (False, True, True),
        "exclude_output_proj": (True, False, True),
        "exclude_mlp": (True, True, False),
    }

    scope = "none" if scope is None else scope.lower().strip()
    assert scope in scope_to_flags, (
        f"Invalid low rank adaptation scope {scope!r}. "
        f"Expect to be one of {list(scope_to_flags)}."
    )

    qkv_proj, output_proj, mlp = scope_to_flags[scope]
    return LoRAScope(qkv_proj=qkv_proj, output_proj=output_proj, mlp=mlp)


874
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
875
876
877
878
879
880
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
881
    head_dim: int
882
        The hidden dimension of each attention head.
883
884
885
886
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

915
916
917
918
919
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
920
    dropout_rng_name: str, default = 'dropout'
921
922
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
923
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
924
        Indicate the type of layer normalization.
925
    layernorm_epsilon: float, default = 1e-6
926
        A value added to the denominator of layer normalization for numerical stability.
927
    zero_centered_gamma: bool, default = False
928
929
930
931
932
933
934
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
935
936
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
937
        Used for initializing the QKV and output projection weights.
938
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
939
    use_bias: bool, default = False
940
        Indicate whether or not to enable bias shifting for QKV and output projections.
941
        If set to False, the layer will not learn additive biases.
942
    bias_init: Initializer, default = flax.linen.initializers.zeros
943
944
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
945
946
947
948
949
950
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
951
952
953
954
955
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
956
957
958
959
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
        where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
960
961
962
963
964
965
966
967
968
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when low rank adaptation
        is enabled through :attr:`low_rank_adaptation_scope`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
969
970
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
971
972
973
974
975
976
977
978
    num_heads: int, default = None
        Deprecated. Please refer to `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer to `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer to `input_layernorm`.
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer to `return_layernorm_output`.
979
980
981

    Optimization parameters
    -----------------------
982
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    fuse_qkv_params: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention (when `num_gqa_groups`
        equals `num_attention_heads`). Otherwise, it exposes a query parameter and
        a fused key-value parameter, e.g. for cross-attention or GQA.
    transpose_batch_sequence: bool, default = True
        Indicate whether the batch and sequence length dimensions of the input tensors
        are swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, the logits are :math:`\frac{Q K^T}{\sqrt{head\_dim}}`,
        else :math:`Q K^T`.
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
1003
1004
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1005
1006
1007
    """

    head_dim: int
1008
1009
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
1010
1011
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
1012
    input_layernorm: bool = True
1013
1014
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
1015
    return_layernorm_output: bool = False
1016
    zero_centered_gamma: bool = False
1017
1018
1019
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
1020
    attn_mask_type: str = "causal"
1021
    attn_bias_type: Optional[str] = None
1022
1023
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1024
1025
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1026
1027
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1028
    dtype: DType = jnp.float32
1029
    fuse_qkv_params: bool = True
1030
    transpose_batch_sequence: bool = True
1031
    enable_sequence_parallel: bool = False
1032
1033
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1034
    float32_logits: bool = False
1035
    window_size: Optional[Tuple[int, int]] = None
1036
1037
1038
1039
1040
1041
1042

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
1043
1044

    def __post_init__(self):
        """Resolve deprecated aliases and derived defaults, then defer to Flax.

        - A deprecated alias that is set (not ``None``) overrides its replacement field
          and emits a ``DeprecationWarning``.
        - ``output_layernorm`` is no longer supported and must be left as ``None``.
        - ``kernel_init`` defaults to ``variance_scaling(1.0, 'fan_in', 'normal')`` in
          ``self.dtype``.
        - ``num_gqa_groups`` defaults to ``num_attention_heads`` (plain MHA).
        """
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            # NOTE(review): not forwarded to return_layernorm_output, because the old
            # flag's semantics may not map one-to-one onto the new one. Confirm intent.
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            # Forward the deprecated alias like num_heads/dropout_rate above. Previously
            # the user's value was silently ignored.
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        # Equal query/kv head counts means plain multi-head attention (no GQA).
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1096
1097
1098
1099
1100
1101
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1102
        inputs_q: jax.numpy.ndarray
1103
            Input tensor for query projection.
1104
        inputs_kv: jax.numpy.ndarray
1105
            Input tensor for key/value projection.
1106
1107
1108
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
1109
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
1110
1111
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1112
        *
1113
        decode: bool, default = False
1114
            Indicate whether to prepare and use an autoregressive cache.
1115
        deterministic: bool, default = False
1116
1117
1118
1119
            Disable dropout layers if set to True.

        Returns
        -------
1120
        outputs: jax.numpy.ndarray
1121
1122
            Output tensors.
        """
1123

1124
1125
1126
1127
1128
        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

1129
        def query_init(*args):
1130
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

1173
1174
1175
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1176

1177
1178
1179
1180
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1181
1182
1183
1184
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1185
1186
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1187
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1188
            if is_qkvpack:
1189
                qkv_proj, ln_out = LayerNormDenseGeneral(
1190
                    enable_layernorm=self.input_layernorm,
1191
                    layernorm_type=self.layernorm_type,
1192
                    zero_centered_gamma=self.zero_centered_gamma,
1193
1194
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1195
1196
                    features=(3, self.num_attention_heads * self.head_dim),
                    return_layernorm_output=self.return_layernorm_output,
1197
1198
1199
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1200
1201
1202
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1203
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1204
1205
1206
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1207
1208
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1209
1210
1211
1212
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1213
                qkv_layout = QKVLayout.BS3HD
1214
1215
            else:
                query, ln_out = LayerNormDenseGeneral(
1216
                    enable_layernorm=self.input_layernorm,
1217
                    layernorm_type=self.layernorm_type,
1218
                    zero_centered_gamma=self.zero_centered_gamma,
1219
1220
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1221
1222
                    features=self.num_attention_heads * self.head_dim,
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1223
1224
1225
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1226
1227
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1228
                    bias_axes=(W_TP_AXES,),
1229
1230
1231
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1232
1233
                    dtype=self.dtype,
                    kernel_init=query_init,
1234
1235
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1236
1237
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1238
1239
1240
1241
1242

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1258
                qkv_layout = QKVLayout.BSHD_BS2HD
1259
1260
1261
1262
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1263
                features=self.num_gqa_groups * self.head_dim,
1264
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1265
1266
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1267
                bias_axes=(W_TP_AXES,),
1268
1269
1270
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1271
1272
                dtype=self.dtype,
            )
1273
            query, ln_out = LayerNormDenseGeneral(
1274
                enable_layernorm=self.input_layernorm,
1275
                layernorm_type=self.layernorm_type,
1276
                zero_centered_gamma=self.zero_centered_gamma,
1277
1278
                epsilon=self.layernorm_epsilon,
                axis=-1,
1279
                features=self.num_attention_heads * self.head_dim,
1280
                return_layernorm_output=True,
1281
1282
1283
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1284
1285
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1286
                bias_axes=(W_TP_AXES,),
1287
1288
1289
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1290
1291
                dtype=self.dtype,
                kernel_init=query_init,
1292
1293
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1294
1295
                name="query",
            )(inputs_q)
1296

1297
            if is_self_attn:
1298
1299
1300
                assert ln_out is not None
                inputs_kv = ln_out

1301
            query = query.astype(input_dtype)
1302
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1303
            key = key.astype(input_dtype)
1304
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
1305
            value = value.astype(input_dtype)
1306
1307
1308
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1309
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1310

1311
        if self.enable_rotary_pos_emb:
1312
1313
1314
1315
1316
1317
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1318

1319
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1320
1321
1322
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1335
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1336

1337
1338
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1339
1340
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1341
1342

        if decode:
1343
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1344
1345
1346
1347
1348
1349
1350
1351
1352
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1353
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1354
                if self.transpose_batch_sequence:
1355
1356
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1357
1358
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1359
1360
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1361
                    one_hot_indices_shape = (1, length, 1, 1)
1362
1363
1364
1365

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1366
1367
1368
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1369

1370
                cur_index = cache_index.value.astype(jnp.int32)
1371
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1372
1373
1374
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1375
1376
1377
1378
1379
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1380
1381
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1382
1383

                if bias is not None:
1384
1385
1386
1387
1388
1389
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1390

1391
1392
1393
1394
1395
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1396
1397
1398
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1410
        else:
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1421
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1435
            window_size=self.window_size,
1436
        )(*dpa_args, mask, bias, deterministic=deterministic)
1437
1438
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1439
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1440
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1441

1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1457

1458
1459
1460
        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
1461
        return out, ln_out
1462
1463


1464
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Positions are built with NumPy, so q_seqlen / k_seqlen must be concrete
        # Python ints; the resulting (q_seqlen, k_seqlen) matrix holds key - query.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket.
        # negative_rp = query - key, i.e. how far the key lies *before* the query.
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in half: keys after the query (negative_rp < 0)
            # are offset into the upper half, then only the distance matters.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Unidirectional: every key after the query collapses to bucket 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Distances below rpb_max_exact get one bucket each; larger distances are
        # binned logarithmically up to max_distance and clamped to the last bucket.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias: one learned scalar per (head, bucket).
        relative_attention_bias = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # Gather the per-head bias of each (query, key) bucket via a one-hot matmul:
        # (heads, buckets) x (buckets, q_seqlen, k_seqlen) -> (heads, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        # Add a leading broadcast (batch) axis.
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


1576
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout ops on the hidden states
        (after the attention blocks and after the FC2 layer).
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden states.
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The name of the RNG stream, supplied via flax.linen.Module.apply,
        that is used for generating dropout masks.
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (default is taken from input of layer norm).
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. The default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not 'none'.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q K^T}{\sqrt{head\_dim}}`,
        else :math:`Q K^T`.
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    # None -> resolved to num_attention_heads in __post_init__.
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    # None -> dtype-aware variance_scaling initializers built in __post_init__.
    mha_kernel_init: Optional[Initializer] = None
    mlp_kernel_init: Optional[Initializer] = None
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_relative_embedding: bool = True
    relative_embedding: Optional[nn.Module] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: Optional[float] = None
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def __post_init__(self):
        """
        Resolve field defaults that depend on other fields.

        ``mha_kernel_init`` / ``mlp_kernel_init`` default to variance-scaling
        initializers built with this layer's ``dtype``. ``num_gqa_groups``
        defaults to ``num_attention_heads``, i.e. plain multi-head attention.
        """
        # These fields are filled in before delegating to the base-class
        # __post_init__, because their defaults depend on self.dtype and
        # self.num_attention_heads and cannot be static dataclass defaults.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: bool = None,
    ):
1803
1804
1805
1806
1807
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
1808
        inputs: jax.numpy.ndarray
1809
            Input tensor.
1810
        encoded: jax.numpy.ndarray, default = None
1811
1812
1813
1814
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
1815
1816
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
1817
        encoder_decoder_mask: jax.numpy.ndarray, default = None
1818
1819
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
1820
            :attr:`True` means mask out the corresponding values.
1821
        deterministic: bool, default = False
1822
            Disable dropout layers if set to True.
1823
        decode: bool, default = False
1824
1825
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
1826
        max_decode_length: bool, default = None
1827
1828
1829
1830
1831
1832
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
1833
        outputs: jax.numpy.ndarray
1834
            Output tensors.
1835
        """
1836

1837
        input_dtype = inputs.dtype
1838
1839
1840
        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
1841

1842
1843
1844
1845
        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )
1846

1847
1848
1849
        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."
1850
1851
1852
1853
1854
1855

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

1856
1857
1858
        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            axes = [None, None]

1859
1860
1861
            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )
1862
1863
1864
1865
1866

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

1867
1868
1869
        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
1870
1871
1872
1873
1874
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
1875
1876
1877
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
1878
1879
                    name="relpos_bias",
                )
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Make name be the exactly same as T5X, since names would affect
        # RNGKey during init and apply. Myabe no need in the feature.
        if self.layer_type == TransformerLayerType.ENCODER:
1897
            mha_name = "attention"
1898
        else:
1899
            mha_name = "self_attention"
1900

1901
        inputs = with_sharding_constraint_by_logical_axes(
1902
1903
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1904

1905
        # [batch, length, emb_dim] -> [batch, length, emb_dim]
1906
1907
1908
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
1909
1910
            dtype=self.dtype,
            head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1911
            num_gqa_groups=self.num_gqa_groups,
1912
            transpose_batch_sequence=self.transpose_batch_sequence,
1913
            enable_sequence_parallel=self.enable_sequence_parallel,
1914
            attention_dropout=self.attention_dropout,
1915
1916
1917
1918
1919
1920
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
1921
            zero_centered_gamma=self.zero_centered_gamma,
1922
1923
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
1924
            attn_mask_type=self.self_attn_mask_type,
1925
            attn_bias_type=self.self_attn_bias_type,
1926
1927
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1928
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1929
1930
1931
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1932
            fuse_qkv_params=self.fuse_qkv_params,
1933
1934
1935
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
1936
            name=mha_name,
1937
            window_size=self.window_size,
1938
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
1939
1940
1941
1942
1943

        def hidden_dropout(x, deterministic):
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1944
                assert -x_shape_len <= dims < x_shape_len
1945

1946
1947
1948
1949
1950
            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1951

1952
        x = with_sharding_constraint_by_logical_axes(
1953
1954
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1955
        residual = with_sharding_constraint_by_logical_axes(
1956
1957
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1958

1959
1960
1961
        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
1962
1963
1964
1965
1966
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1967
1968
1969
1970
1971

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

1972
1973
1974
1975
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
1976
1977
1978
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."
1979

1980
            x = with_sharding_constraint_by_logical_axes(
1981
1982
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1983

1984
1985
1986
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
1987
1988
                dtype=self.dtype,
                head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1989
                num_gqa_groups=self.num_gqa_groups,
1990
                transpose_batch_sequence=self.transpose_batch_sequence,
1991
                enable_sequence_parallel=self.enable_sequence_parallel,
1992
                attention_dropout=self.attention_dropout,
1993
1994
1995
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
1996
                zero_centered_gamma=self.zero_centered_gamma,
1997
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
1998
1999
2000
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
2001
2002
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
2003
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
2004
2005
2006
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
2007
2008
2009
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
2010
                fuse_qkv_params=self.fuse_qkv_params,
2011
2012
2013
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
2014
                name="encoder_decoder_attention",
2015
                window_size=self.window_size,
2016
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)
2017
2018

            y = with_sharding_constraint_by_logical_axes(
2019
2020
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
2021
            residual = with_sharding_constraint_by_logical_axes(
2022
2023
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
2024

2025
            y = hidden_dropout(y, deterministic)
2026
2027
2028
2029
2030

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

2031
2032
            mlp_input = y + residual

2033
        mlp_input = with_sharding_constraint_by_logical_axes(
2034
2035
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2036

2037
2038
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

2039
2040
2041
2042
        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
2043
            zero_centered_gamma=self.zero_centered_gamma,
2044
2045
2046
2047
            epsilon=self.layernorm_epsilon,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
2048
2049
2050
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
2051
            dtype=self.dtype,
2052
2053
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
2054
            kernel_init=self.mlp_kernel_init,
2055
2056
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
2057
2058
            use_bias=self.use_bias,
            bias_init=self.bias_init,
2059
2060
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
2061
2062
2063
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
2064
2065
2066
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
2067
            name="mlp",
2068
2069
2070
2071
2072
2073
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

2074
        z = with_sharding_constraint_by_logical_axes(
2075
2076
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2077
        residual = with_sharding_constraint_by_logical_axes(
2078
2079
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2080

2081
2082
2083
        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
2084
2085
2086
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
2087
2088
2089
        z = z + residual

        if self.output_layernorm:
2090
            z = with_sharding_constraint_by_logical_axes(
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
2102
        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
2103
        return z