# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
from math import sqrt
import os
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
from jax.ad_checkpoint import checkpoint_name

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
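    # e.g. len(shape) == 4 with batch_dim == 1 -> [0, 2, 3], i.e. every dim except batch.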
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please set the axes of kernels and biases properly.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules.
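
    Examples
    --------
    A minimal sketch, assuming a mesh with ``('data', 'model')`` axes and that the
    ShardingResource has already been set via ``fp8_autocast``::

        rules = (("my_batch_axis", "data"),)  # user-chosen logical axis name
        extended_rules = extend_logical_axis_rules(rules)
        # ``extended_rules`` now also contains TE's predefined TransformerLayer
        # mappings for the batch/sequence/head/hidden logical axes.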
    """
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key = item[0]
        val = item[1]
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        if key in rules_map:
            rules_map[key].append(val)
        else:
            rules_map[key] = [val]

    extended_rules = [*rules]
    for item in get_sharding_map_logic_axis_to_mesh_axis().items():
        key = item[0]
        val = item[1]
        if key in rules_map:
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and the given rules. "
                f"Axis {key} maps to {rules_map[key]} in the given"
                f" rules, but to {val} in TE's rules."
            )
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    float32_logits: bool = False
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
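            # e.g. h_q=16, h_kv=4 -> group_size=4 and grouped_query has shape
            # [..., h_kv, group_size, head_dim]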
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias).
        # In this case, the scale cannot be fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into the Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # The cuDNN backend currently supports SWA only for causal/padding_causal; follow that here.
            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
                if mask is not None:
                    return SoftmaxType.SCALED_MASKED, mask
                return SoftmaxType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
            attn_weights, mask, bias
        ).astype(self.dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
            attn_weights = attn_weights * multiplier

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)


class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:

        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if self.qkv_layout.is_qkvpacked():
            """qkvpacked format, treat
            query: qkvpacked tensor, shape = [..., 3, h, d]
            key: ignore
            value: ignore
            """
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            x = fused_attn(
                (qkv_packed,),
                bias,
                sequence_descriptor,
                seed,
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
            )
        elif self.qkv_layout.is_kvpacked():
            """kvpacked format, treat
            query: query tensor, shape = [..., h, d]
            key: kvpacked tensor, shape = [..., 2, h, d]
            value: ignore
            """
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            x = fused_attn(
                (query, kv_packed),
                bias,
                sequence_descriptor,
                seed,
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
            )
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            x = fused_attn(
                (query, key, value),
                bias,
                sequence_descriptor,
                seed,
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
            )
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        return x


class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None`, it defaults to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

        .. note:: THD format only supports 'padding' or 'causal_padding' mask type.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When left as the default, the type is decided automatically by the bias parameter:
        :attr:`post_scale_bias` if a bias is given, otherwise :attr:`no_bias`.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separate tensors with shape = [b, s, h, d].
        * t3hd/thd_t2hd/thd_thd_thd: Same layouts as the bshd series, but they allow multiple
          sequences to be packed in a batch, also known as sequence packing.

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for models like T5X, which do not
        need to scale the query; this is achieved by setting :attr:`scale_factor=1.0`.
    transpose_batch_sequence: bool, default = True
        Indicates whether the batch and sequence-length axes of the input tensors are swapped.
        If set to True, the input tensors should be in (seqlen, batch, ...), otherwise
        (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
    context_parallel_causal_load_balanced: bool, default = False
        Indicates whether the sequences are reordered for causal-mask load balancing when
        running context parallelism.
    context_parallel_axis: str, default = ''
        The name of the context parallel axis.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type of the module parameters.
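
    Examples
    --------
    A minimal sketch with illustrative shapes; since :attr:`transpose_batch_sequence=True`
    by default, the inputs below are [seqlen, batch, heads, head_dim]::

        import jax
        import jax.numpy as jnp

        attn = DotProductAttention(head_dim=64, num_attention_heads=16)
        q = k = v = jnp.zeros((128, 4, 16, 64), jnp.float32)
        variables = attn.init(jax.random.PRNGKey(0), q, k, v)
        out = attn.apply(variables, q, k, v, deterministic=True)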
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: AttnBiasType = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of the query tensor representation are described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of the key tensor representation are described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of the value tensor representation are described in :attr:`qkv_layout`.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            The parameters below are keyword-only.
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """

        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For the internal API, we use enums to represent these options.
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout == QKVLayout.BS3HD:
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            self.head_dim,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(self.head_dim)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if not use_fused_attn:
            # The unfused attention only supports separate query, key, and value tensors.
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray)

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )

        return x


def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding
    x should be in shape of
695
696
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.
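
    A minimal usage sketch (illustrative shapes)::

        x = jnp.zeros((4, 128, 16, 64))  # [batch, seqlen, heads, hidden]
        out = rotary_pos_emb(x, windows=(1, 10000), transpose_batch_sequence=False)
        assert out.shape == x.shape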
    """
    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2
    min_window = windows[0]
    max_window = windows[1]

    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
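        # Pairs coordinate i with i + hidden_dim/2 by rotating the two halves of x.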
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1)
        return output

    def consecutive_impl():
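        # Pairs adjacent coordinates (2i, 2i+1) via neighbor shifts and alternating signs.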
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid relative positional embedding group method. "
753
            f"Expect to be in []'alternate' or 'consecutive'], but got {gm}."
754
        )

        return canonicalized_gm

    group_method = canonicalize_group_method(group_method)

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()


class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope"""

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        return (self.qkv_proj, self.output_proj, self.mlp) == (
            other.qkv_proj,
            other.output_proj,
            other.mlp,
        )


def _canonicalize_lora_scope(scope):
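    # Maps a scope string to a LoRAScope, e.g. "exclude_mlp" ->
    # LoRAScope(qkv_proj=True, output_proj=True, mlp=False).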

    SCOPE_NONE = "none"
    SCOPE_ALL = "all"
    SCOPE_QKV_PROJ = "qkv_proj"
    SCOPE_OUTPUT_PROJ = "output_proj"
    SCOPE_MLP = "mlp"
    SCOPE_EX_QKV_PROJ = "exclude_qkv_proj"
    SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj"
    SCOPE_EX_MLP = "exclude_mlp"

    scope = SCOPE_NONE if scope is None else scope

    scope = scope.lower()

    assert scope in [
        SCOPE_NONE,
        SCOPE_ALL,
        SCOPE_QKV_PROJ,
        SCOPE_OUTPUT_PROJ,
        SCOPE_MLP,
        SCOPE_EX_QKV_PROJ,
        SCOPE_EX_OUTPUT_PROJ,
        SCOPE_EX_MLP,
    ]

    lora_scope = LoRAScope()

    if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
        lora_scope.qkv_proj = True

    if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
        lora_scope.output_proj = True

    if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
        lora_scope.mlp = True

    return lora_scope


class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None`, it defaults to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When left as the default, the type is decided automatically by the MHA's bias parameter:
        `post_scale_bias` if a bias is given, otherwise `no_bias`.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing the QKV and output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    use_bias: bool, default = False
        Indicate whether or not to enable bias shifting for QKV and output projections.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
        where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope in which to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not 'none'.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    num_heads: int, default = None
        Deprecated. Please refer to `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer to `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer to `input_layernorm`.
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer to `return_layernorm_output`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type of the module parameters.
    fuse_qkv_params: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = True
        Indicates whether the batch and sequence-length axes of the input tensors are swapped.
        If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise
        (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}} * K`,
        else :math:`Q * K`.
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer to `fuse_qkv_params`.
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
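
    Examples
    --------
    A minimal self-attention sketch with illustrative shapes; since
    :attr:`transpose_batch_sequence=True` by default, the input below is
    [seqlen, batch, hidden]::

        import jax
        import jax.numpy as jnp

        mha = MultiHeadAttention(head_dim=64, num_attention_heads=16)
        x = jnp.zeros((128, 4, 1024), jnp.float32)
        variables = mha.init(jax.random.PRNGKey(0), x, x)
        out = mha.apply(variables, x, x, deterministic=True)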
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    return_layernorm_output: bool = False
    zero_centered_gamma: bool = False
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False
    window_size: Optional[Tuple[int, int]] = None

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed soon. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed soon. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed soon; please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed soon. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed soon. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.weight_dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q: jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv: jax.numpy.ndarray
            Input tensor for key/value projection.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
        *:
            The parameters below are keyword-only.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache.
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """

        def query_init(*args):
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
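            # shape is (hidden_in, 3, num_heads * head_dim); q/k/v kernels are drawn
            # with separate RNG keys so the scaled query init only affects q.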
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

        def generate_batch_seqlen_logical_axes(is_sharded_seq):
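            # e.g. transpose_batch_sequence=True -> (SEQLEN[_TP]_AXES, BATCH_AXES),
            # otherwise (BATCH_AXES, SEQLEN[_TP]_AXES).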
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa

        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        if self.fuse_qkv_params:
            if is_qkvpack:
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_attention_heads * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.return_layernorm_output,
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name="qkv",
                    dtype=self.dtype,
                    weight_dtype=self.weight_dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
                qkv_layout = QKVLayout.BS3HD
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_attention_heads * self.head_dim,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_TP_AXES,),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    dtype=self.dtype,
                    weight_dtype=self.weight_dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name="query",
                )(inputs_q)

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name="kv",
                    dtype=self.dtype,
1206
                    weight_dtype=self.weight_dtype,
1207
1208
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1209
                qkv_layout = QKVLayout.BSHD_BS2HD
1210
1211
1212
1213
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
                features=self.num_gqa_groups * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
            )
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=self.input_layernorm,
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_attention_heads * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
                name="query",
            )(inputs_q)

            if is_self_attn:
                assert ln_out is not None
                inputs_kv = ln_out

            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            key = key.astype(self.dtype)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

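        # Rotary position embeddings are applied per head to query and key, so the
        # fused QKV/KV projections are first split back into separate tensors; after
        # this block the layout is always BSHD_BSHD_BSHD.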
        if self.enable_rotary_pos_emb:
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            # No change to the memory layout; this should trigger only a bitcast
            # (ideally with no performance impact).
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

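        # Autoregressive decoding: key/value caches are allocated with the shape of
        # the first call's key/value tensors. Each later call writes the new token's
        # key/value at position `cache_index` through a one-hot mask (JAX arrays are
        # immutable, so this keeps the update jit-friendly), advances the index, and
        # masks out cache entries beyond the current position.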
        if decode:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                if self.transpose_batch_sequence:
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )

                cur_index = cache_index.value.astype(jnp.int32)
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )

                if bias is not None:
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )

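        # At this point the projections are in one of three layouts:
        #   BS3HD          - fused QKV (self-attention with fuse_qkv_params=True)
        #   BSHD_BS2HD     - separate Q, fused KV (cross-attention with fuse_qkv_params=True)
        #   BSHD_BSHD_BSHD - separate Q, K, V (fuse_qkv_params=False, RoPE, or decoding)
        # Each branch below reshapes to the per-head layout and applies the matching
        # logical-axis sharding constraint before calling DotProductAttention.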
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
        else:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

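        # Softmax scale for the attention logits: 1/sqrt(head_dim) when
        # scale_attn_logits is enabled, as in "Attention Is All You Need".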
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            weight_dtype=self.weight_dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
            window_size=self.window_size,
        )(*dpa_args, mask, bias, deterministic=deterministic)

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)

        out = DenseGeneral(
            features=inputs_q.shape[-1],
            transpose_batch_sequence=self.transpose_batch_sequence,
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            weight_dtype=self.weight_dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")

        return out, ln_out


class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings added to the attention logits.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard the embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type of the module parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

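        # T5-style bucketing: the first rpb_max_exact distances each get their own
        # bucket, larger distances are binned logarithmically up to max_distance, and
        # anything beyond max_distance falls into the last bucket. For example, with
        # num_buckets=32 and max_distance=128 (bidirectional), distances 0..7 map
        # exactly and 8..127 share the log-spaced buckets 8..15 in each direction.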
        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            negative_rp = np.maximum(negative_rp, 0)

        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = nn_partitioning.param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_attention_heads, self.num_buckets),
            self.weight_dtype,
            axes=self.embedding_axes,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        return values[jnp.newaxis, ...]


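# A minimal usage sketch for RelativePositionBiases (illustrative only, assuming
# default initializers; shapes follow the docstring above):
#
#     relpos_bias = RelativePositionBiases(
#         num_buckets=32, max_distance=128, num_attention_heads=8)
#     variables = relpos_bias.init(jax.random.PRNGKey(0), 128, 128, bidirectional=True)
#     bias = relpos_bias.apply(variables, 128, 128, bidirectional=True)
#     assert bias.shape == (1, 8, 128, 128)

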
class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer.

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When set to `None`, it defaults to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op after the FC2 layer.
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden dropout.
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after the FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after the FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the Multi-Head Attention.
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (by default they are taken from the input of layer norm).
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. The default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For the fused attention backend, the accumulation is always float32 without the
        performance overhead.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like the `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding',
        'padding_causal'}, each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When the default is used, the type is automatically decided by the MHA's bias
        parameter: `post_scale_bias` if a bias is given, otherwise `no_bias`.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`.
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp'].
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope != 'none'`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism for operations other than dot.
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value means no sliding window.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type of the module parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, the `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = False
        Indicate whether the input tensors have the batch and sequence-length axes
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden);
        otherwise, (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}}*K`,
        else :math:`Q*K`.
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`.
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def __post_init__(self):
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.weight_dtype
            )
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: bool = None,
    ):
        """
        Transformer Layer: an attention block and a feedforward network (MLP).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
            :attr:`True` means mask out the corresponding values.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length: bool, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """

        inputs = inputs.astype(self.dtype)

        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be a multiple of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )

        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            axes = [None, None]

            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

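        # e.g. with transpose_batch_sequence=False and enable_sequence_parallel=True,
        # generate_batch_seqlen_logical_axes() == (BATCH_AXES, SEQLEN_TP_AXES).
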
        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    weight_dtype=self.weight_dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Keep the names exactly the same as T5X, since names affect the RNGKey
        # during init and apply. Maybe this will no longer be needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = "attention"
        else:
            mha_name = "self_attention"

        inputs = with_sharding_constraint_by_logical_axes(
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            weight_dtype=self.weight_dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name,
            window_size=self.window_size,
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
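        # `ln_out` is only populated when apply_residual_connection_post_layernorm is
        # enabled (return_layernorm_output above); it re-roots the residual at the
        # LayerNorm output further below.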

        def hidden_dropout(x, deterministic):
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        x = with_sharding_constraint_by_logical_axes(
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = with_sharding_constraint_by_logical_axes(
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name="encoder_decoder_attention",
                window_size=self.window_size,
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

            y = with_sharding_constraint_by_logical_axes(
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            residual = with_sharding_constraint_by_logical_axes(
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = with_sharding_constraint_by_logical_axes(
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            weight_dtype=self.weight_dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            name="mlp",
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = with_sharding_constraint_by_logical_axes(
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
        z = z + residual

        if self.output_layernorm:
            z = with_sharding_constraint_by_logical_axes(
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                transpose_batch_sequence=self.transpose_batch_sequence,
                dtype=self.dtype,
                weight_dtype=self.weight_dtype,
                name="output_layernorm",
            )(z)

        return z
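

# A minimal usage sketch for TransformerLayer (illustrative only; assumes a single
# device, default hyper-parameters, and hypothetical shapes -- not part of the
# module code):
#
#     import jax
#     import jax.numpy as jnp
#
#     layer = TransformerLayer(hidden_size=512, mlp_hidden_size=2048,
#                              num_attention_heads=8)
#     x = jnp.zeros((2, 128, 512))  # (batch, seqlen, hidden)
#     variables = layer.init(
#         {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}, x)
#     y = layer.apply(variables, x, deterministic=True)  # same shape as x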