# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
18
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
19
from flax.linen.attention import combine_masks
20
21
22
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
23
from jax.ad_checkpoint import checkpoint_name
24
25
26

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
27
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
28
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
29
from ..attention import fused_attn
30
from ..softmax import SoftmaxType
31
32
33
34
35
36
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
37
38
39
40
41

# Type aliases used in the signatures of this module.
# PRNG keys are passed around opaquely.
PRNGKey = Any
# A static tensor shape.
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Anything accepted as an XLA matmul precision: a name, an enum, or a (lhs, rhs) pair.
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
# Parameter initializer signature: (rng, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical axis rules: a sequence of (logical_axis_name, mesh_axis_name_or_None).
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend. Each rule is an
        ``(axis_name, mesh_axis_name)`` pair, where ``mesh_axis_name`` may be ``None``.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, as a tuple: all given rules first (in order),
        followed by each TE rule whose axis name does not appear in ``rules``.

    Raises
    ------
    AssertionError
        If a rule is not a 2-item pair, if its names have the wrong types, or if an axis
        name that TE also defines is mapped differently (or more than once) in ``rules``.
    """
    # axis_name -> every mesh axis it is mapped to in the user-provided rules.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key = item[0]
        val = item[1]
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for key, val in get_sharding_map_logic_axis_to_mesh_axis().items():
        if key in rules_map:
            # A user rule for a TE axis is only accepted if it agrees exactly with TE.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append((key, val))
    return tuple(extended_rules)


class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention built from plain JAX ops (einsum plus TE's ``Softmax`` module).

    :class:`DotProductAttention` uses this backend when fused attention is not selected.
    It handles both MHA and grouped-query attention (GQA), where the query head count is a
    multiple of the key/value head count. It also supports optional pre/post-scale bias,
    padding/causal masks combined with a sliding-window mask, and attention dropout.
    """

    # Dropout probability applied to the attention weights after the softmax.
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    # Not read by __call__; kept for interface parity with the fused backend.
    dtype: DType = jnp.float32
    # If True, query and key are upcast to float32 before the logits are computed.
    float32_logits: bool = False
    # None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # True: inputs are [seqlen, batch, ...]; False: [batch, seqlen, ...].
    transpose_batch_sequence: bool = True
    # Sliding-window size passed to make_swa_mask; None means no sliding window.
    window_size: Optional[Tuple[int, int]] = None

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Compute attention.

        Parameters
        ----------
        query, key, value: Array
            Rank-4 tensors laid out as ``[seqlen, batch, heads, head_dim]`` when
            ``transpose_batch_sequence`` is True, otherwise ``[batch, seqlen, heads, head_dim]``.
            ``key``/``value`` may have fewer heads than ``query`` (GQA), provided the query
            head count is a multiple of the key/value head count.
        mask: Optional[Array]
            Attention mask whose last two axes are (seqlen_q, seqlen_kv). Entries equal to 1
            mark positions to mask out. Ignored for 'no_mask', and for 'causal' without a
            sliding window.
        bias: Optional[Array]
            Attention bias, applied according to ``attn_bias_type``.
        dropout_rng: Optional[PRNGKey]
            RNG key for attention dropout. It is used only when dropout is active.
        deterministic: bool
            If True, attention dropout is skipped.

        Returns
        -------
        Array
            The attention output, in the same layout as ``query``.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        input_dtype = query.dtype

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (h_kv, group_size) so each KV head serves one group.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        # Logits are always produced as [batch, heads(, groups), seqlen_q, seqlen_kv].
        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Fold the group axis into the head axis for the softmax; undone after it.
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                # NOTE(review): `bias` is also passed to Softmax below; confirm that Softmax
                # does not add it a second time for PRE_SCALE_BIAS.
                attn_weights += bias

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            # Positions already masked in original_mask stay masked; others take the SWA value.
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
                if mask is not None:
                    return SoftmaxType.SCALED_MASKED, mask
                return SoftmaxType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
            attn_weights, mask, bias
        ).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Inverted dropout: rescale kept weights by 1 / keep_prob.
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)


class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention dispatched to the fused ``fused_attn`` primitive.

    The inputs are arranged according to ``qkv_layout`` (packed QKV, packed KV, or separate
    tensors). When ``transpose_batch_sequence`` is True they are transposed to batch-major
    before the call and the output is transposed back afterwards.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    # None means 1 / sqrt(head_dim), with head_dim taken from query.shape[-1].
    scale_factor: Optional[float] = None
    # True: inputs are [seqlen, batch, ...]; False: [batch, seqlen, ...].
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Compute attention with ``fused_attn``.

        Parameters
        ----------
        query, key, value: Array
            Interpreted according to ``qkv_layout``:

            * qkv-packed: ``query`` is the packed tensor ``[..., 3, h, d]``; ``key`` and
              ``value`` are ignored.
            * kv-packed: ``query`` is ``[..., h, d]`` and ``key`` is the packed tensor
              ``[..., 2, h, d]``; ``value`` is ignored.
            * separate: ``query``, ``key`` and ``value`` are all ``[..., h, d]``.
        sequence_descriptor: Optional[SequenceDescriptor]
            Passed unchanged to ``fused_attn``.
        bias: Optional[Array]
            Passed unchanged to ``fused_attn``.
        dropout_rng: Optional[PRNGKey]
            If given, it is split into ``num_of_devices()`` keys, which are passed to
            ``fused_attn`` as the seed.
        deterministic: bool
            Passed to ``fused_attn`` as ``is_training=not deterministic``.

        Returns
        -------
        Array
            The attention output, in the same batch/sequence order as the inputs.

        Raises
        ------
        ValueError
            If ``qkv_layout`` is neither qkv-packed, kv-packed nor separate.
        """
        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # Arrange the inputs as fused_attn expects them for this layout (batch-major).
        if self.qkv_layout.is_qkvpacked():
            # `query` holds the packed QKV tensor [..., 3, h, d]; key/value are ignored.
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            # `query` is [..., h, d]; `key` holds the packed KV tensor [..., 2, h, d];
            # `value` is ignored.
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            qkv,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
        )

        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        assert x.dtype == query.dtype
        return x


class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

        .. note:: THD format only supports 'padding' or 'causal_padding' mask type.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
    context_parallel_causal_load_balanced: bool, default = False
        Indicates the sequences are ordered for causal mask load balancing when running context
        parallelism.
    context_parallel_axis: str, default = ""
        The name of the context parallel axis.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    # String name; canonicalized to an AttnMaskType in __call__.
    attn_mask_type: str = "causal"
    # String name (or None); converted to an AttnBiasType in __call__.
    attn_bias_type: Optional[str] = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    # String name; converted to a QKVLayout in __call__.
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
            Ignored for qkv-packed layouts.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        sequence_descriptor: Optional[Union[SequenceDescriptor, jax.numpy.ndarray]], default = None
            Sequence/mask information. The unfused backend only accepts a mask array here
            (:attr:`True` means to mask out the corresponding values). The fused backend
            passes it unchanged to ``fused_attn``.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: jax.numpy.ndarray, default = None
            Deprecated alias of :attr:`sequence_descriptor`; using it emits a warning, and
            providing both raises ``ValueError``.
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = query.dtype

        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        # From here on only the enum locals are used; presumably the raw string fields are
        # deleted to prevent accidental use below.
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # Documented contract: num_gqa_groups=None is equivalent to num_attention_heads (MHA).
        num_gqa_groups = (
            self.num_attention_heads if self.num_gqa_groups is None else self.num_gqa_groups
        )

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout.is_qkvpacked():
            # `key` is ignored for every qkv-packed layout (bs3hd and t3hd), so it must not be
            # inspected here.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            self.head_dim,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(self.head_dim)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if not use_fused_attn:
            # unfused attention only supports splitted query, key, value
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            # The unfused backend only understands a plain mask array.
            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x


689
690
691
692
693
694
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding

    Rotates pairs of coordinates in the last (hidden) dimension of ``x`` by the
    angle ``position / time_scale``, where ``position`` is the index along the
    sequence axis.

    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: Array
        Input tensor. Its last (hidden) dimension must be even.
    windows: Tuple[int, int]
        ``(min_window, max_window)``. Coordinate pair ``k`` uses the time-scale
        ``min_window * (max_window / min_window) ** (2 * k / hidden)``.
    transpose_batch_sequence: bool
        Selects whether axis 0 (False) or axis 1 (True) is the sequence axis.
    group_method: str, default = 'consecutive'
        How coordinates are paired. 'consecutive' pairs index ``i`` with ``i + 1``.
        'alternate' pairs index ``i`` with ``i + hidden / 2``. Case, surrounding
        whitespace, '-' and '_' are ignored.

    Returns
    -------
    Array with the same shape and dtype as ``x``.

    Raises
    ------
    AssertionError
        If ``group_method`` is not recognised or the hidden dimension is odd.
    """

    def canonicalize_group_method(gm):
        # Accept spellings such as "Consecutive", " alternate " or "alter_nate".
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid rotary positional embedding group method. "
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."
        )
        return canonicalized_gm

    # Validate the arguments up front so bad input fails before any array work.
    group_method = canonicalize_group_method(group_method)

    hidden_dim = x.shape[-1]
    # Coordinates are rotated in pairs. With an odd hidden size the frequency
    # table built below does not line up with x; for hidden_dim == 1 the
    # 'consecutive' path would even silently broadcast to an empty result.
    assert (
        hidden_dim % 2 == 0
    ), f"Rotary positional embedding requires an even hidden dimension, got {hidden_dim}."
    half_hidden_dim = hidden_dim // 2

    min_window = windows[0]
    max_window = windows[1]

    # Geometric progression of time-scales, one per coordinate pair, starting at
    # min_window and growing towards max_window.
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Prepend singleton axes so time_scales broadcasts over every axis but the last.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Position index laid out along seq_dim; singleton along batch and trailing axes.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        # Pair coordinate i (first half) with coordinate i + hidden_dim / 2 (second half).
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
        return output

    def consecutive_impl():
        # Pair coordinate 2k with 2k + 1; both members of a pair share one time-scale.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        # Even index i picks its partner x[i + 1]; odd index i picks x[i - 1].
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 on even indices and +1 on odd ones, giving
        # out[2k] = x[2k] cos - x[2k+1] sin and out[2k+1] = x[2k+1] cos + x[2k] sin.
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()
766
767


768
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    Boolean flags recording which parts of the model low rank adaptation (LoRA)
    is enabled for.

    Attributes
    ----------
    qkv_proj: bool
        Flag for the 'qkv_proj' scope.
    output_proj: bool
        Flag for the 'output_proj' scope.
    mlp: bool
        Flag for the 'mlp' scope.
    """

    # The flags are plain mutable attributes, so instances are deliberately
    # unhashable (this is also what defining __eq__ alone implies).
    __hash__ = None

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Defer to Python's default handling for unrelated types instead of
        # raising AttributeError (e.g. for `scope == None`).
        if not isinstance(other, LoRAScope):
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == (
            other.qkv_proj,
            other.output_proj,
            other.mlp,
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj}, "
            f"output_proj={self.output_proj}, mlp={self.mlp})"
        )
782
783
784
785


def _canonicalize_lora_scope(scope):
    """Translate a low rank adaptation scope string into a :class:`LoRAScope`.

    Parameters
    ----------
    scope: Optional[str]
        One of 'none', 'all', 'qkv_proj', 'output_proj', 'mlp',
        'exclude_qkv_proj', 'exclude_output_proj' or 'exclude_mlp'. Matching is
        case-insensitive and ignores surrounding whitespace. ``None`` is treated
        as 'none'.

    Returns
    -------
    LoRAScope
        A scope whose qkv_proj / output_proj / mlp flags are set according to
        the requested scope.

    Raises
    ------
    AssertionError
        If ``scope`` is not one of the accepted values.
    """
    # Accepted scope name -> (qkv_proj, output_proj, mlp) flags.
    scope_table = {
        "none": (False, False, False),
        "all": (True, True, True),
        "qkv_proj": (True, False, False),
        "output_proj": (False, True, False),
        "mlp": (False, False, True),
        "exclude_qkv_proj": (False, True, True),
        "exclude_output_proj": (True, False, True),
        "exclude_mlp": (True, True, False),
    }

    scope = "none" if scope is None else scope
    canonical_scope = scope.lower().strip()

    assert canonical_scope in scope_table, (
        f"Invalid low rank adaptation scope {scope!r}. "
        f"Expect to be one of {list(scope_table)}."
    )

    qkv_proj, output_proj, mlp = scope_table[canonical_scope]
    return LoRAScope(qkv_proj=qkv_proj, output_proj=output_proj, mlp=mlp)


824
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
825
826
827
828
829
830
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
831
    head_dim: int
832
        The hidden dimension of each attention head.
833
834
835
836
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
zlsh80826's avatar
zlsh80826 committed
837
838
839
840
841
842
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the querys.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
843
844
845
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

865
866
867
868
869
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
870
    dropout_rng_name: str, default = 'dropout'
871
872
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
873
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
874
        Indicate the type of layer normalization.
875
    layernorm_epsilon: float, default = 1e-6
876
        A value added to the denominator of layer normalization for numerical stability.
877
    zero_centered_gamma: bool, default = False
878
879
880
881
882
883
884
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
885
886
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
887
        Used for initializing the QKV and output projection weights.
888
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
889
    use_bias: bool, default = False
890
        Indicate whether or not to enable bias shifting for QKV and output projections.
891
        If set to False, the layer will not learn additive biases.
892
    bias_init: Initializer, default = flax.linen.initializers.zeros
893
894
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
895
896
897
898
899
900
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
901
902
903
904
905
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
906
907
908
909
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to coupled the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
910
911
912
913
914
915
916
917
918
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
919
920
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
921
922
923
924
925
926
927
928
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`.
929
930
931

    Optimization parameters
    -----------------------
932
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
933
        The data type used to allocate the initial parameters.
934
    fuse_qkv_params: bool, default = True
935
        If set to True, this module exposes a single fused
936
937
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
938
    transpose_batch_sequence: bool, default = True
939
        Indicate whether the input tensors were switched axis of batch
940
941
942
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
943
        Indicate whether to scale attention logits.
944
        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}*K}`,
945
        else :math:`Q*K`
946
947
948
949
950
951
952
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
953
954
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
955
956
957
    """

    head_dim: int
958
959
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
960
961
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
962
    input_layernorm: bool = True
963
964
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
965
    return_layernorm_output: bool = False
966
    zero_centered_gamma: bool = False
967
968
969
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
970
    attn_mask_type: str = "causal"
971
    attn_bias_type: Optional[str] = None
972
973
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
974
975
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
976
977
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
978
    dtype: DType = jnp.float32
979
    fuse_qkv_params: bool = True
980
    transpose_batch_sequence: bool = True
981
    enable_sequence_parallel: bool = False
982
983
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
984
    float32_logits: bool = False
985
    window_size: Optional[Tuple[int, int]] = None
986
987
988
989
990
991
992

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
993
994

    def __post_init__(self):
        """Resolve deprecated aliases and derived defaults, then run Module setup.

        Each deprecated field, when set, is copied onto its replacement and a
        ``DeprecationWarning`` is emitted:

        * ``num_heads`` -> ``num_attention_heads``
        * ``dropout_rate`` -> ``attention_dropout``
        * ``apply_residual_connection_post_layernorm`` -> ``return_layernorm_output``
        * ``fuse_qkv`` -> ``fuse_qkv_params``

        ``output_layernorm`` has no replacement and must be left as ``None``.

        If unset, ``kernel_init`` defaults to
        ``variance_scaling(1.0, "fan_in", "normal")`` in ``self.dtype``.
        ``num_gqa_groups`` defaults to ``num_attention_heads``, which is plain MHA.
        Finally the call is delegated to ``nn.Module.__post_init__``.
        """
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please uses {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            # Previously this value was only warned about and then dropped.
            # Forward it like the other aliases so the caller's request is honoured.
            self.return_layernorm_output = self.apply_residual_connection_post_layernorm
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            # NOTE: fuse_qkv_params makes __call__ choose between the fused
            # "qkv"/"kv" submodules and separate "query"/"key"/"value" ones.
            # Honouring the alias therefore selects the parameter layout the
            # caller asked for, rather than silently using the default.
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1046
1047
1048
1049
1050
1051
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1052
        inputs_q: jax.numpy.ndarray
1053
            Input tensor for query projection.
1054
        inputs_kv: jax.numpy.ndarray
1055
            Input tensor for key/value projection.
1056
1057
1058
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
1059
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
1060
1061
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1062
        *
1063
        decode: bool, default = False
1064
            Indicate whether to prepare and use an autoregressive cache.
1065
        deterministic: bool, default = False
1066
1067
1068
1069
            Disable dropout layers if set to True.

        Returns
        -------
1070
        outputs: jax.numpy.ndarray
1071
1072
            Output tensors.
        """
1073

1074
1075
1076
1077
1078
        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

1079
        def query_init(*args):
1080
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

1123
1124
1125
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1126

1127
1128
1129
1130
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1131
1132
1133
1134
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1135
1136
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1137
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1138
            if is_qkvpack:
1139
                qkv_proj, ln_out = LayerNormDenseGeneral(
1140
                    enable_layernorm=self.input_layernorm,
1141
                    layernorm_type=self.layernorm_type,
1142
                    zero_centered_gamma=self.zero_centered_gamma,
1143
1144
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1145
                    features=(3, self.num_attention_heads * self.head_dim),
1146
                    transpose_batch_sequence=self.transpose_batch_sequence,
1147
                    return_layernorm_output=self.return_layernorm_output,
1148
1149
1150
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1151
1152
1153
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1154
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1155
1156
1157
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1158
1159
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1160
1161
1162
1163
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1164
                qkv_layout = QKVLayout.BS3HD
1165
1166
            else:
                query, ln_out = LayerNormDenseGeneral(
1167
                    enable_layernorm=self.input_layernorm,
1168
                    layernorm_type=self.layernorm_type,
1169
                    zero_centered_gamma=self.zero_centered_gamma,
1170
1171
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1172
                    features=self.num_attention_heads * self.head_dim,
1173
                    transpose_batch_sequence=self.transpose_batch_sequence,
1174
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1175
1176
1177
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1178
1179
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1180
                    bias_axes=(W_TP_AXES,),
1181
1182
1183
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1184
1185
                    dtype=self.dtype,
                    kernel_init=query_init,
1186
1187
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1188
1189
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1190
1191
1192
1193
1194

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1211
                qkv_layout = QKVLayout.BSHD_BS2HD
1212
1213
1214
1215
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1216
                features=self.num_gqa_groups * self.head_dim,
1217
                transpose_batch_sequence=self.transpose_batch_sequence,
1218
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1219
1220
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1221
                bias_axes=(W_TP_AXES,),
1222
1223
1224
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1225
1226
                dtype=self.dtype,
            )
1227
            query, ln_out = LayerNormDenseGeneral(
1228
                enable_layernorm=self.input_layernorm,
1229
                layernorm_type=self.layernorm_type,
1230
                zero_centered_gamma=self.zero_centered_gamma,
1231
1232
                epsilon=self.layernorm_epsilon,
                axis=-1,
1233
                features=self.num_attention_heads * self.head_dim,
1234
1235
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
1236
1237
1238
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1239
1240
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1241
                bias_axes=(W_TP_AXES,),
1242
1243
1244
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1245
1246
                dtype=self.dtype,
                kernel_init=query_init,
1247
1248
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1249
1250
                name="query",
            )(inputs_q)
1251

1252
            if is_self_attn:
1253
1254
1255
                assert ln_out is not None
                inputs_kv = ln_out

1256
            query = query.astype(input_dtype)
1257
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1258
            key = key.astype(input_dtype)
1259
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
1260
            value = value.astype(input_dtype)
1261
1262
1263
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1264
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1265

1266
        if self.enable_rotary_pos_emb:
1267
1268
1269
1270
1271
1272
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1273

1274
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1275
1276
1277
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1290
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1291

1292
1293
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1294
1295
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1296
1297

        if decode:
1298
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1299
1300
1301
1302
1303
1304
1305
1306
1307
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1308
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1309
                if self.transpose_batch_sequence:
1310
1311
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1312
1313
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1314
1315
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1316
                    one_hot_indices_shape = (1, length, 1, 1)
1317
1318
1319
1320

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1321
1322
1323
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1324

1325
                cur_index = cache_index.value.astype(jnp.int32)
1326
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1327
1328
1329
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1330
1331
1332
1333
1334
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1335
1336
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1337
1338

                if bias is not None:
1339
1340
1341
1342
1343
1344
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1345

1346
1347
1348
1349
1350
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1351
1352
1353
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1365
        else:
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1376
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1390
            window_size=self.window_size,
1391
        )(*dpa_args, mask, bias, deterministic=deterministic)
1392
1393
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1394
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1395
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1396

1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            transpose_batch_sequence=self.transpose_batch_sequence,
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1413

1414
1415
1416
        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
1417
        return out, ln_out
1418
1419


1420
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative position embeddings that are added to the attention logits.

    Every (query, key) position pair is assigned to one of ``num_buckets`` buckets
    according to the signed distance between the two positions, and each bucket
    owns one learned bias value per attention head.

    Parameters
    ----------
    num_buckets: int
        Number of buckets that key/query position distances are assigned to.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Initializer for the relative embedding table.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        Logical axis names used to shard the embedding table over a device mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @staticmethod
    def _bucket_ids(relative_position, num_buckets, max_distance, bidirectional):
        """
        Map signed ``key_pos - query_pos`` offsets to integer bucket ids (NumPy, host side).

        Small distances get a dedicated bucket each; larger distances share
        logarithmically spaced buckets, clamped to the last available bucket.
        """
        distance = -relative_position
        offset = 0
        if bidirectional:
            # Split the buckets in half: the upper half is reserved for keys
            # located after the query (negative ``distance``).
            num_buckets //= 2
            offset = (distance < 0).astype(np.int32) * num_buckets
            distance = np.abs(distance)
        else:
            # Keys after the query all collapse onto distance 0.
            distance = np.maximum(distance, 0)

        max_exact = num_buckets // 2
        # eps keeps log() finite at distance 0; those entries are discarded by
        # the np.where below because they fall in the "exact" range anyway.
        log_scaled = (
            np.log(distance.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        )
        far_bucket = np.minimum(max_exact + log_scaled.astype(np.int32), num_buckets - 1)
        return offset + np.where(distance < max_exact, distance, far_bucket)

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Build the relative position attention bias.

        Parameters
        ----------
        q_seqlen: int
            Length of the query sequence.
        k_seqlen: int
            Length of the key sequence.
        bidirectional: bool, default = True
            If True, keys before and after a query are bucketed separately
            (each side uses half of the buckets). If False, keys after the
            query are treated as if they were at distance 0.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        query_pos = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        key_pos = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        bucket_ids = self._bucket_ids(
            key_pos - query_pos, self.num_buckets, self.max_distance, bidirectional
        )

        # Learned table of shape (num_attention_heads, num_buckets).
        bias_table = nn_partitioning.param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
            axes=self.embedding_axes,
        )
        bias_table = jnp.asarray(bias_table, self.dtype)

        # One-hot encode bucket ids along a new leading axis:
        # (q, k) -> (num_buckets, q, k), so one contraction gathers all biases.
        bucket_range = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        one_hot = jnp.array(bucket_ids[jnp.newaxis, ...] == bucket_range, dtype=self.dtype)

        # (heads, buckets) x (buckets, q, k) -> (heads, q, k)
        per_head_bias = lax.dot_general(bias_table, one_hot, (((1,), (0,)), ((), ())))
        return per_head_bias[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    """
    Selects which kind of block a TransformerLayer builds.

    Values
    ------
    ENCODER:
        Encoder flavour of TransformerLayer.
    DECODER:
        Decoder flavour of TransformerLayer; it adds a cross-attention block
        over the encoder output after the self-attention block.
    """

    # Member order is significant for iteration; keep ENCODER first.
    ENCODER = "encoder"
    DECODER = "decoder"


1533
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
1534
1535
1536
1537
1538
1539
1540
1541
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
1542
        The hidden size of each input sample.
1543
    mlp_hidden_size: int, default = 2048
1544
        Intermediate size to which input samples are projected.
1545
    num_attention_heads: int, default = 8
1546
        Number of attention heads in the transformer layer.
1547
    num_gqa_groups: int, default = `None`
zlsh80826's avatar
zlsh80826 committed
1548
1549
1550
1551
1552
1553
1554
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1555
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1556
        Indicate the type of layer normalization.
1557
    layernorm_epsilon: float, default = 1e-6
1558
        A value added to the denominator of layer normalization for numerical stability.
1559
    zero_centered_gamma: bool, default = False
1560
1561
1562
1563
1564
1565
1566
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1567
    hidden_dropout: float, default = 0.1
1568
        Dropout probability for the dropout op after FC2 layer.
1569
    hidden_dropout_dims: Sequence[int], default = ()
1570
        Dimensions that will share the same dropout mask for hidden states.
1571
    attention_dropout: float, default = 0.1
1572
        Dropout probability for the dropout op during multi-head attention.
1573
1574
1575
1576
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
1577
    dropout_rng_name: str, default = 'dropout'
1578
1579
        The key in the RNGs given via flax.linen.Module.apply that is used for
        generating dropout masks (attention, hidden and drop-path dropout).
1580
1581
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
1582
1583
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1584
1585
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
1586
1587
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1588
    mlp_activations: Sequence[str], default = ('relu', )
1589
        The sequence of activation functions to apply after the first linear transformation.
1590
1591
        Each activation has its own transformation layer.
    use_bias: bool, default = False
1592
1593
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
1594
    bias_init: Initializer, default = flax.linen.initializers.zeros
1595
1596
1597
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1598
    apply_residual_connection_post_layernorm: bool, default = False
1599
        If set to True, residual connections are taken from the output
1600
1601
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
1602
        If set to True, layer normalization is applied on the output side,
1603
1604
1605
        after the final dropout-add. Default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
1606
1607
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
1608
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
1609
        If set to TransformerLayerType.DECODER, an additional cross-attention block
1610
1611
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
1612
    self_attn_mask_type: str, default = 'causal'
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

1632
1633
1634
1635
1636
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
1637
    enable_relative_embedding: bool, default = True
1638
        Whether to enable relative embedding as shifting of attention logits.
1639
    relative_embedding: flax.linen.Module, default = None
1640
        The module for relative embedding execution, only used when
1641
1642
1643
1644
1645
1646
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
1647
1648
1649
1650
1651
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1652
    rotary_pos_emb_group_method: str, default = 'consecutive'
1653
1654
1655
1656
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
1657
1658
1659
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
1660
        'exclude_output_proj', 'exclude_mlp']
1661
1662
1663
1664
1665
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not 'none'.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
1666
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
1667
1668
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1669
1670
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1671
1672
1673

    Optimization parameters
    -----------------------
1674
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1675
        The data type used to allocate the initial parameters.
1676
    drop_path: float, default = 0.0
1677
        When > 0.0, applies stochastic depth per sample in the main
1678
1679
        path of the residual block.
    fuse_qkv_params: bool, default = True
1680
        If set to True, `TransformerLayer` module exposes a single fused
1681
1682
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1683
    transpose_batch_sequence: bool, default = False
1684
        Indicate whether the input tensors were switched axis of batch
1685
1686
1687
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
1688
        Indicate whether to scale attention logits.
1689
1690
1691
        If set to True, :math:`\frac{Q K^T}{\sqrt{head\_dim}}`,
        else :math:`Q K^T`.
    scaled_query_init: bool, default = `True`
1692
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
1693
1694
1695
1696
1697
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
1698
    num_gqa_groups: Optional[int] = None
1699
    layernorm_type: str = "layernorm"
1700
    layernorm_epsilon: float = 1e-6
1701
    zero_centered_gamma: bool = False
1702
1703
1704
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
1705
1706
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
1707
    dropout_rng_name: str = "dropout"
1708
1709
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
1710
    mlp_activations: Sequence[str] = ("relu",)
1711
1712
1713
1714
1715
1716
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
1717
    self_attn_mask_type: str = "causal"
1718
    self_attn_bias_type: Optional[str] = None
1719
1720
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
1721
1722
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1723
1724
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1725
1726
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1727
1728
1729
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
1730
    transpose_batch_sequence: bool = False
1731
    enable_sequence_parallel: bool = False
1732
1733
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1734
    window_size: Optional[Tuple[int, int]] = None
1735
1736
1737

    def __post_init__(self):
        """Resolve ``None`` defaults that depend on other fields, then finish Flax init."""
        # Kernel initializers default to fan-in variance scaling in the module's
        # dtype: a normal distribution for attention projections and a
        # truncated normal for the MLP. Dataclass defaults cannot read
        # ``self.dtype``, so the defaults are filled in here instead.
        for field_name, distribution in (
            ("mha_kernel_init", "normal"),
            ("mlp_kernel_init", "truncated_normal"),
        ):
            if getattr(self, field_name) is None:
                setattr(
                    self,
                    field_name,
                    nn.initializers.variance_scaling(
                        1.0, "fan_in", distribution, dtype=self.dtype
                    ),
                )

        # One KV group per head is plain multi-head attention.
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads

        super().__post_init__()

    @nn.compact
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: bool = None,
    ):
1760
1761
1762
1763
1764
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
1765
        inputs: jax.numpy.ndarray
1766
            Input tensor.
1767
        encoded: jax.numpy.ndarray, default = None
1768
1769
1770
1771
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
1772
1773
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
1774
        encoder_decoder_mask: jax.numpy.ndarray, default = None
1775
1776
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
1777
            :attr:`True` means mask out the corresponding values.
1778
        deterministic: bool, default = False
1779
            Disable dropout layers if set to True.
1780
        decode: bool, default = False
1781
1782
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
1783
        max_decode_length: bool, default = None
1784
1785
1786
1787
1788
1789
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
1790
        outputs: jax.numpy.ndarray
1791
            Output tensors.
1792
        """
1793

1794
        input_dtype = inputs.dtype
1795
1796
1797
        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
1798

1799
1800
1801
1802
        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )
1803

1804
1805
1806
        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."
1807
1808
1809
1810
1811
1812

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

1813
1814
1815
        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            axes = [None, None]

1816
1817
1818
            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )
1819
1820
1821
1822
1823

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

1824
1825
1826
        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
1827
1828
1829
1830
1831
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
1832
1833
1834
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
1835
1836
                    name="relpos_bias",
                )
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Keep names exactly the same as T5X, since names affect the
        # RNG keys during init and apply. This may no longer be needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
1854
            mha_name = "attention"
1855
        else:
1856
            mha_name = "self_attention"
1857

1858
        inputs = with_sharding_constraint_by_logical_axes(
1859
1860
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1861

1862
        # [batch, length, emb_dim] -> [batch, length, emb_dim]
1863
1864
1865
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
1866
1867
            dtype=self.dtype,
            head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1868
            num_gqa_groups=self.num_gqa_groups,
1869
            transpose_batch_sequence=self.transpose_batch_sequence,
1870
            enable_sequence_parallel=self.enable_sequence_parallel,
1871
            attention_dropout=self.attention_dropout,
1872
1873
1874
1875
1876
1877
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
1878
            zero_centered_gamma=self.zero_centered_gamma,
1879
1880
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
1881
            attn_mask_type=self.self_attn_mask_type,
1882
            attn_bias_type=self.self_attn_bias_type,
1883
1884
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1885
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1886
1887
1888
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1889
            fuse_qkv_params=self.fuse_qkv_params,
1890
1891
1892
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
1893
            name=mha_name,
1894
            window_size=self.window_size,
1895
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
1896
1897
1898
1899
1900

        def hidden_dropout(x, deterministic):
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1901
                assert -x_shape_len <= dims < x_shape_len
1902

1903
1904
1905
1906
1907
            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1908

1909
        x = with_sharding_constraint_by_logical_axes(
1910
1911
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1912
        residual = with_sharding_constraint_by_logical_axes(
1913
1914
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1915

1916
1917
1918
        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
1919
1920
1921
1922
1923
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1924
1925
1926
1927
1928

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

1929
1930
1931
1932
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
1933
1934
1935
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."
1936

1937
            x = with_sharding_constraint_by_logical_axes(
1938
1939
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1940

1941
1942
1943
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
1944
1945
                dtype=self.dtype,
                head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1946
                num_gqa_groups=self.num_gqa_groups,
1947
                transpose_batch_sequence=self.transpose_batch_sequence,
1948
                enable_sequence_parallel=self.enable_sequence_parallel,
1949
                attention_dropout=self.attention_dropout,
1950
1951
1952
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
1953
                zero_centered_gamma=self.zero_centered_gamma,
1954
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
1955
1956
1957
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
1958
1959
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1960
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1961
1962
1963
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1964
1965
1966
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
1967
                fuse_qkv_params=self.fuse_qkv_params,
1968
1969
1970
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1971
                name="encoder_decoder_attention",
1972
                window_size=self.window_size,
1973
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)
1974
1975

            y = with_sharding_constraint_by_logical_axes(
1976
1977
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1978
            residual = with_sharding_constraint_by_logical_axes(
1979
1980
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1981

1982
            y = hidden_dropout(y, deterministic)
1983
1984
1985
1986
1987

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

1988
1989
            mlp_input = y + residual

1990
        mlp_input = with_sharding_constraint_by_logical_axes(
1991
1992
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1993

1994
1995
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1996
1997
1998
1999
        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
2000
            zero_centered_gamma=self.zero_centered_gamma,
2001
2002
2003
2004
2005
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
2006
2007
2008
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
2009
            dtype=self.dtype,
2010
2011
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
2012
            kernel_init=self.mlp_kernel_init,
2013
2014
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
2015
2016
            use_bias=self.use_bias,
            bias_init=self.bias_init,
2017
2018
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
2019
2020
2021
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
2022
2023
2024
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
2025
            name="mlp",
2026
2027
2028
2029
2030
2031
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

2032
        z = with_sharding_constraint_by_logical_axes(
2033
2034
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2035
        residual = with_sharding_constraint_by_logical_axes(
2036
2037
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2038

2039
2040
2041
        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
2042
2043
2044
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
2045
2046
2047
        z = z + residual

        if self.output_layernorm:
2048
            z = with_sharding_constraint_by_logical_axes(
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                transpose_batch_sequence=self.transpose_batch_sequence,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
2061
        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
2062
        return z