# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
18
from flax.linen.attention import combine_masks
19
20
21
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
27
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
28
from ..attention import fused_attn
29
from ..softmax import SoftmaxType
30
31
32
33
34
35
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
36
37
38
39
40

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
41
42
43
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
44
45
46
47
48
49
50
51
52
53
54
55
56
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    """Return the broadcast dims for drop_path.

    The result is the ascending list of axis indices of ``shape`` with the
    entry at position ``batch_dim`` removed. ``batch_dim`` follows normal list
    indexing, so negative values count from the end and an out-of-range value
    raises ``IndexError``.
    """
    broadcast_dims = [*range(len(shape))]
    del broadcast_dims[batch_dim]
    return broadcast_dims


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules. The given rules come first, in their original
        order, followed by every TE rule whose axis name is not present in the given rules.

    Raises
    ------
    AssertionError
        If a rule is not a pair, if an axis name is not a ``str``, if a mesh axis name is
        neither ``str`` nor ``None``, or if an axis name shared with TE's rules maps to a
        different mesh axis (or to more than one mesh axis) in the given rules.
    """
    # Collect every mesh axis each logical axis name is mapped to by the caller's rules.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for key, val in get_sharding_map_logic_axis_to_mesh_axis().items():
        if key in rules_map:
            # A caller-provided rule may overlap with TE's only when it agrees with it exactly.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append((key, val))
    return tuple(extended_rules)


112
113
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Dot product attention built from JAX-native operations.

    The module computes the attention logits with ``jnp.einsum``, normalizes them with the
    ``Softmax`` module (optionally masked / biased / scaled), applies dropout, and contracts
    the resulting weights with ``value``. Both multi-head attention and grouped-query
    attention (query heads being a multiple of key/value heads) are handled.
    """

    # Dropout probability applied to the attention weights after the softmax.
    attention_dropout: float = 0.0
    # Mask type; decides which SoftmaxType is used and whether `mask` is consulted.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    # Bias type; decides whether the scale is applied before or inside the softmax.
    attn_bias_type: Optional[AttnBiasType] = None
    # Not read inside this module's __call__.
    dtype: DType = jnp.float32
    # When True, query and key are cast to float32 before the logits are computed.
    float32_logits: bool = False
    # Logit scale; None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # True: inputs are [seqlen, batch, heads, head_dim]; False: [batch, seqlen, heads, head_dim].
    transpose_batch_sequence: bool = True
    # Sliding window forwarded to make_swa_mask when a mask is in use.
    window_size: Optional[Tuple[int, int]] = None

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """
        Compute attention over ``query``, ``key`` and ``value``.

        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Tensors of equal rank laid out as [seqlen, batch, heads, head_dim] when
            :attr:`transpose_batch_sequence` is True, otherwise
            [batch, seqlen, heads, head_dim]. ``key`` and ``value`` must share sequence
            length and head count; ``query`` and ``key`` must share head_dim.
        mask: jax.numpy.ndarray, default = None
            Mask in which non-zero entries are masked out. Its first axis is read as the
            batch size and its last two axes as (seqlen_q, seqlen_kv). Dropped for
            ``NO_MASK`` and for ``CAUSAL_MASK`` without a sliding window.
        bias: jax.numpy.ndarray, default = None
            Bias forwarded to the Softmax module (and added to the logits directly for
            ``PRE_SCALE_BIAS``).
        dropout_rng: PRNGKey, default = None
            Random key used to sample the dropout mask.
        deterministic: bool, default = False
            Disable dropout if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Attention output with ``query``'s batch/sequence/head layout and ``value``'s
            head dimension as the last axis.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        # Remember the caller's dtype: the softmax output is cast back to it below.
        input_dtype = query.dtype

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # NOTE(review): removes the instance attribute so only the local `scale_factor`
        # is used from here on -- presumably intentional; confirm before changing.
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query head axis into (kv_heads, group_size) so each group of query
            # heads can be contracted against its shared key/value head.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        # Logits are always produced batch-major: [b, h, q, k] (or [b, h, g, q, k] for GQA).
        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Merge (h, g) into a single head axis for the softmax; the grouped shape is
            # kept so it can be restored afterwards.
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                # NOTE(review): `bias` is added here and is also passed to Softmax below;
                # confirm Softmax does not apply it a second time.
                attn_weights += bias

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            # Keep positions already masked; elsewhere take the sliding-window decision.
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxType.SCALED_MASKED, mask
            # NOTE(review): a causal mask type with a window but no `mask` reaches this
            # branch, so the sliding window is not applied in that case -- confirm that
            # callers always supply a mask when window_size is set.
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
            attn_weights, mask, bias
        ).astype(input_dtype)

        if is_gqa:
            # Restore the [b, h, g, q, k] shape expected by the grouped output einsum.
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Kept entries are rescaled by 1 / keep_prob; dropped entries become zero.
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"
        # Contract the weights with `value`, returning to the caller's batch/sequence layout.
        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
262
263


264
265
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Dot product attention that delegates the whole computation to ``fused_attn``.

    The module only prepares the inputs: it derives the dropout seed and the scale
    factor, arranges the tensors according to :attr:`qkv_layout`, swaps the first two
    axes of the inputs and of the output when :attr:`transpose_batch_sequence` is set,
    and forwards every other setting unchanged.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """
        Run fused attention.

        ``query``/``key``/``value`` are interpreted according to :attr:`qkv_layout`:

        * qkv-packed: ``query`` is the packed tensor with shape [..., 3, h, d];
          ``key`` and ``value`` are ignored.
        * kv-packed: ``query`` has shape [..., h, d] and ``key`` is the packed tensor
          with shape [..., 2, h, d]; ``value`` is ignored.
        * separate: ``query``, ``key`` and ``value`` are individual tensors.

        Raises ``ValueError`` for any other layout.
        """
        # One seed per device, only when a dropout key was supplied.
        seed = None if dropout_rng is None else jax.random.split(dropout_rng, num_of_devices())

        scale_factor = self.scale_factor
        if scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        del self.scale_factor

        swap_batch_and_seq = self.transpose_batch_sequence

        # Select the tensors handed to the kernel; the first two axes are swapped first
        # when the caller's layout is transposed.
        if self.qkv_layout.is_qkvpacked():
            qkv_packed = query
            if swap_batch_and_seq:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            attn_inputs = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            kv_packed = key
            if swap_batch_and_seq:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            attn_inputs = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if swap_batch_and_seq:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            attn_inputs = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            attn_inputs,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
        )

        # Hand the result back in the layout the caller used.
        if swap_batch_and_seq:
            x = x.transpose([1, 0, 2, 3])

        assert x.dtype == query.dtype
        return x


382
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

        .. note:: THD format only supports 'padding' or 'causal_padding' mask type.

       attn_mask_type       mask/sequence_descriptor       SWA          softmax type
       --------------------------------------------------------------------------------------------
       no_mask              None                           None         SCALED
       causal               None                           None         SCALED_UPPER_TRIANG_MASKED
       causal               None                           Yes          SCALED_MASKED
       padding              Required                       Yes/No       SCALED_MASKED
       padding_causal       Required                       Yes/No       SCALED_MASKED

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
    context_parallel_causal_load_balanced: bool, default = False
        Indicates the sequences are ordered for causal mask load balancing when running context
        parallelism.
    context_parallel_axis: str, default = ''
        The name of the context parallel axis.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    # The public API takes strings here; they are converted to the internal enums in __call__.
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        sequence_descriptor: Union[SequenceDescriptor, jax.numpy.ndarray], default = None
            Masking information for the attention. The unfused backend only accepts
            :attr:`None` or an array, which it uses as a boolean mask where :attr:`True`
            means to mask out the corresponding values. The fused backend receives the
            value unchanged.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: Union[SequenceDescriptor, jax.numpy.ndarray], default = None
            Deprecated alias of :attr:`sequence_descriptor`; a warning is issued when it is
            used. It cannot be provided together with :attr:`sequence_descriptor`.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.

        Raises
        ------
        ValueError
            If both :attr:`sequence_descriptor` and :attr:`mask` are provided.
        """
        input_dtype = query.dtype

        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        # NOTE(review): removes the string-valued attributes so only the enum locals above
        # are used below -- presumably intentional; confirm before changing.
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        # In every qkv-packed layout the keys/values live inside `query` and `key` is
        # ignored, so it must not be dereferenced here (it may not be a tensor at all).
        if qkv_layout.is_qkvpacked():
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            self.head_dim,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
            )

        # A dropout key is only drawn when dropout will actually be applied.
        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(self.head_dim)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if not use_fused_attn:
            # unfused attention only supports split query, key, value
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            # The unfused backend consumes the descriptor as a plain mask array.
            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x
696
697


698
699
700
701
702
703
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding

    Rotates pairs of coordinates along the last (hidden) axis of :attr:`x` by the angle
    ``position / time_scale``. The time scales start at ``windows[0]`` and grow
    geometrically towards ``windows[1]`` across the coordinate pairs.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor in shape of
        [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
        [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.
        The hidden dimension must be even because coordinates are rotated in pairs.
    windows: Tuple[int, int]
        The (min, max) time-scales of the embedding.
    transpose_batch_sequence: bool
        Whether the sequence axis comes before the batch axis in :attr:`x`.
    group_method: str, default = 'consecutive'
        How coordinates are paired. 'consecutive' pairs index i with i + 1, 'alternate'
        pairs index i with i + Hidden / 2. Matching is case-insensitive and ignores
        surrounding whitespace as well as '-' and '_'.

    Returns
    -------
    outputs: jax.numpy.ndarray
        The rotated tensor, with the same shape and dtype as :attr:`x`.

    Raises
    ------
    AssertionError
        If :attr:`group_method` is neither 'consecutive' nor 'alternate'.
    ValueError
        If the hidden dimension of :attr:`x` is odd.
    """
    # Validate the cheap, static arguments before building any array.
    canonicalized_gm = group_method.lower().strip().replace("-", "").replace("_", "")
    assert canonicalized_gm in ["consecutive", "alternate"], (
        "Invalid relative positional embedding group method. "
        f"Expect to be in ['alternate', 'consecutive'], but got {group_method}."
    )

    hidden_dim = x.shape[-1]
    if hidden_dim % 2 != 0:
        # Without this check the failure surfaces later as an obscure split/broadcast error.
        raise ValueError(
            "Rotary positional embedding requires an even hidden dimension, "
            f"but got {hidden_dim}."
        )
    half_hidden_dim = hidden_dim // 2
    min_window = windows[0]
    max_window = windows[1]

    # One time scale per coordinate pair, shaped [1, ..., 1, Hidden / 2] for broadcasting.
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Token positions, with size-1 axes everywhere except the sequence axis.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        return jnp.sin(sinusoidal_positions), jnp.cos(sinusoidal_positions)

    if canonicalized_gm == "alternate":
        # Pair index i with i + Hidden / 2: rotate the two halves against each other.
        sin, cos = generate_sin_cos(time_scales)
        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)
        return jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)

    # Pair index i with i + 1: every pair shares one time scale, hence the repeat.
    sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

    parity = jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2)
    is_odd = jnp.broadcast_to(parity == 1, x.shape)

    # Each coordinate needs its partner: odd indices take the left neighbour (i - 1),
    # even indices take the right neighbour (i + 1).
    x_shifted = jax.lax.select(
        is_odd,
        jnp.roll(x, 1, axis=-1),
        jnp.roll(x, -1, axis=-1),
    )

    # -1 for even indices, +1 for odd indices (the sign pattern of a 2D rotation).
    sign = jnp.sign(parity - 0.5)

    output = x * cos + x_shifted * sin * sign
    return output.astype(x.dtype)
775
776


777
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    Records which groups of projections Low-Rank Adaptation is enabled for.

    Parameters
    ----------
    qkv_proj: bool, default = False
        Whether LoRA is applied to the query/key/value projections.
    output_proj: bool, default = False
        Whether LoRA is applied to the output projection.
    mlp: bool, default = False
        Whether LoRA is applied to the MLP.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of failing with
        # AttributeError on e.g. `scope == None`.
        if not isinstance(other, LoRAScope):
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == (
            other.qkv_proj,
            other.output_proj,
            other.mlp,
        )

    # Instances are mutable and compare by value, so they must stay unhashable.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj!r}, "
            f"output_proj={self.output_proj!r}, mlp={self.mlp!r})"
        )
791
792
793
794


def _canonicalize_lora_scope(scope):
    """
    Convert a low rank adaptation scope name into a :class:`LoRAScope`.

    Parameters
    ----------
    scope: Optional[str]
        One of 'none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj' or 'exclude_mlp' (case-insensitive). `None` is treated
        as 'none'.

    Returns
    -------
    lora_scope: LoRAScope
        The flags telling which projections LoRA is enabled for.

    Raises
    ------
    AssertionError
        If :attr:`scope` is not one of the supported names.
    """
    SCOPE_NONE = "none"
    SCOPE_ALL = "all"
    SCOPE_QKV_PROJ = "qkv_proj"
    SCOPE_OUTPUT_PROJ = "output_proj"
    SCOPE_MLP = "mlp"
    SCOPE_EX_QKV_PROJ = "exclude_qkv_proj"
    SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj"
    SCOPE_EX_MLP = "exclude_mlp"

    scope = SCOPE_NONE if scope is None else scope

    scope = scope.lower()

    valid_scopes = [
        SCOPE_NONE,
        SCOPE_ALL,
        SCOPE_QKV_PROJ,
        SCOPE_OUTPUT_PROJ,
        SCOPE_MLP,
        SCOPE_EX_QKV_PROJ,
        SCOPE_EX_OUTPUT_PROJ,
        SCOPE_EX_MLP,
    ]
    assert (
        scope in valid_scopes
    ), f"Invalid low rank adaptation scope. Expect to be in {valid_scopes}, but got {scope}."

    lora_scope = LoRAScope()

    # Each "exclude_*" scope enables every projection except the one it names.
    if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
        lora_scope.qkv_proj = True

    if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
        lora_scope.output_proj = True

    if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
        lora_scope.mlp = True

    return lora_scope


833
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
834
835
836
837
838
839
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
840
    head_dim: int
841
        The hidden dimension of each attention head.
842
843
844
845
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
zlsh80826's avatar
zlsh80826 committed
846
847
848
849
850
851
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
852
853
854
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

874
875
876
877
878
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
879
    dropout_rng_name: str, default = 'dropout'
880
881
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
882
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
883
        Indicate the type of layer normalization.
884
    layernorm_epsilon: float, default = 1e-6
885
        A value added to the denominator of layer normalization for numerical stability.
886
    zero_centered_gamma: bool, default = False
887
888
889
890
891
892
893
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
894
895
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
896
        Used for initializing the QKV and output projection weights.
897
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
898
    use_bias: bool, default = False
899
        Indicate whether or not to enable bias shifting for QKV and output projections.
900
        If set to False, the layer will not learn additive biases.
901
    bias_init: Initializer, default = flax.linen.initializers.zeros
902
903
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
904
905
906
907
908
909
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
910
911
912
913
914
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
915
916
917
918
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
        where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
919
920
921
922
923
924
925
926
927
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not 'none'.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
928
929
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
930
931
932
933
934
935
936
937
    num_heads: int, default = None
        Deprecated. Please refer to `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer to `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer to `input_layernorm`.
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer to `return_layernorm_output`.
938
939
940

    Optimization parameters
    -----------------------
941
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
942
        The data type used to allocate the initial parameters.
943
    fuse_qkv_params: bool, default = True
944
        If set to True, this module exposes a single fused
945
946
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
947
    transpose_batch_sequence: bool, default = True
948
        Indicate whether the input tensors were switched axis of batch
949
950
951
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
952
        Indicate whether to scale attention logits.
953
        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}} * K`,
954
        else :math:`Q*K`
955
956
957
958
959
960
961
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer to `fuse_qkv_params`.
962
963
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
964
965
966
    """

    head_dim: int
967
968
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
969
970
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
971
    input_layernorm: bool = True
972
973
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
974
    return_layernorm_output: bool = False
975
    zero_centered_gamma: bool = False
976
977
978
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
979
    attn_mask_type: str = "causal"
980
    attn_bias_type: Optional[str] = None
981
982
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
983
984
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
985
986
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
987
    dtype: DType = jnp.float32
988
    fuse_qkv_params: bool = True
989
    transpose_batch_sequence: bool = True
990
    enable_sequence_parallel: bool = False
991
992
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
993
    float32_logits: bool = False
994
    window_size: Optional[Tuple[int, int]] = None
995
996
997
998
999
1000
1001

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
1002
1003

    def __post_init__(self):
        """
        Resolve deprecated parameters and fill in defaults that depend on other fields.

        Every deprecated parameter that is set emits a `DeprecationWarning` and its value
        is forwarded to the parameter that replaces it, except `output_layernorm`, which
        is no longer accepted at all. Afterwards `kernel_init` defaults to
        `variance_scaling(1.0, 'fan_in', 'normal')` in :attr:`dtype`, and
        `num_gqa_groups` defaults to `num_attention_heads`.

        Raises
        ------
        AssertionError
            If the removed `output_layernorm` parameter is given.
        """
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please uses {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            # Forward the value like the other deprecated aliases instead of silently
            # dropping what the caller asked for.
            self.return_layernorm_output = self.apply_residual_connection_post_layernorm
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            # Same as above: honour the deprecated alias rather than ignoring it.
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        # Fields must be finalized before handing over to the parent class.
        super().__post_init__()

    @nn.compact
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1055
1056
1057
1058
1059
1060
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1061
        inputs_q: jax.numpy.ndarray
1062
            Input tensor for query projection.
1063
        inputs_kv: jax.numpy.ndarray
1064
            Input tensor for key/value projection.
1065
1066
1067
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
1068
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
1069
1070
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1071
        *
1072
        decode: bool, default = False
1073
            Indicate whether to prepare and use an autoregressive cache.
1074
        deterministic: bool, default = False
1075
1076
1077
1078
            Disable dropout layers if set to True.

        Returns
        -------
1079
        outputs: jax.numpy.ndarray
1080
1081
            Output tensors.
        """
1082

1083
1084
1085
1086
1087
        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

1088
        def query_init(*args):
1089
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

1132
1133
1134
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1135

1136
1137
1138
1139
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1140
1141
1142
1143
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1144
1145
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1146
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1147
            if is_qkvpack:
1148
                qkv_proj, ln_out = LayerNormDenseGeneral(
1149
                    enable_layernorm=self.input_layernorm,
1150
                    layernorm_type=self.layernorm_type,
1151
                    zero_centered_gamma=self.zero_centered_gamma,
1152
1153
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1154
                    features=(3, self.num_attention_heads * self.head_dim),
1155
                    transpose_batch_sequence=self.transpose_batch_sequence,
1156
                    return_layernorm_output=self.return_layernorm_output,
1157
1158
1159
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1160
1161
1162
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1163
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1164
1165
1166
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1167
1168
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1169
1170
1171
1172
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1173
                qkv_layout = QKVLayout.BS3HD
1174
1175
            else:
                query, ln_out = LayerNormDenseGeneral(
1176
                    enable_layernorm=self.input_layernorm,
1177
                    layernorm_type=self.layernorm_type,
1178
                    zero_centered_gamma=self.zero_centered_gamma,
1179
1180
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1181
                    features=self.num_attention_heads * self.head_dim,
1182
                    transpose_batch_sequence=self.transpose_batch_sequence,
1183
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1184
1185
1186
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1187
1188
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1189
                    bias_axes=(W_TP_AXES,),
1190
1191
1192
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1193
1194
                    dtype=self.dtype,
                    kernel_init=query_init,
1195
1196
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1197
1198
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1199
1200
1201
1202
1203

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1220
                qkv_layout = QKVLayout.BSHD_BS2HD
1221
1222
1223
1224
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1225
                features=self.num_gqa_groups * self.head_dim,
1226
                transpose_batch_sequence=self.transpose_batch_sequence,
1227
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1228
1229
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1230
                bias_axes=(W_TP_AXES,),
1231
1232
1233
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1234
1235
                dtype=self.dtype,
            )
1236
            query, ln_out = LayerNormDenseGeneral(
1237
                enable_layernorm=self.input_layernorm,
1238
                layernorm_type=self.layernorm_type,
1239
                zero_centered_gamma=self.zero_centered_gamma,
1240
1241
                epsilon=self.layernorm_epsilon,
                axis=-1,
1242
                features=self.num_attention_heads * self.head_dim,
1243
1244
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
1245
1246
1247
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1248
1249
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1250
                bias_axes=(W_TP_AXES,),
1251
1252
1253
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1254
1255
                dtype=self.dtype,
                kernel_init=query_init,
1256
1257
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1258
1259
                name="query",
            )(inputs_q)
1260

1261
            if is_self_attn:
1262
1263
1264
                assert ln_out is not None
                inputs_kv = ln_out

1265
            query = query.astype(input_dtype)
1266
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1267
            key = key.astype(input_dtype)
1268
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
1269
            value = value.astype(input_dtype)
1270
1271
1272
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1273
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1274

1275
        if self.enable_rotary_pos_emb:
1276
1277
1278
1279
1280
1281
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1282

1283
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1284
1285
1286
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1299
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1300

1301
1302
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1303
1304
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1305
1306

        if decode:
1307
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1308
1309
1310
1311
1312
1313
1314
1315
1316
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1317
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1318
                if self.transpose_batch_sequence:
1319
1320
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1321
1322
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1323
1324
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1325
                    one_hot_indices_shape = (1, length, 1, 1)
1326
1327
1328
1329

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1330
1331
1332
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1333

1334
                cur_index = cache_index.value.astype(jnp.int32)
1335
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1336
1337
1338
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1339
1340
1341
1342
1343
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1344
1345
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1346
1347

                if bias is not None:
1348
1349
1350
1351
1352
1353
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1354

1355
1356
1357
1358
1359
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1360
1361
1362
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1374
        else:
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1385
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1399
            window_size=self.window_size,
1400
        )(*dpa_args, mask, bias, deterministic=deterministic)
1401
1402
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1403
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1404
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1405

1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            transpose_batch_sequence=self.transpose_batch_sequence,
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1422

1423
1424
1425
        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
1426
        return out, ln_out
1427
1428


1429
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Query positions form a column and key positions a row, so the
        # difference broadcasts to a (q_seqlen, k_seqlen) grid of offsets.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        # negative_rp = query position - key position.
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in two halves: keys located after the query
            # (negative_rp < 0) are shifted into the upper half, then only the
            # absolute distance is bucketed.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Keys located after the query are clamped to distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Distances below rpb_max_exact each get their own bucket; larger
        # distances are placed on a logarithmic scale up to max_distance.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        # eps keeps the argument of the log strictly positive when the distance is 0.
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        # Anything beyond max_distance falls into the last bucket.
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # One-hot encode the bucket ids along a leading num_buckets axis:
        # (1, q_seqlen, k_seqlen) == (num_buckets, 1, 1) -> (num_buckets, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        # Contract the bucket axis of the (heads, buckets) table with the one-hot
        # tensor, which selects the per-head bias for every (query, key) pair:
        # result is (num_attention_heads, q_seqlen, k_seqlen).
        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        # Add the leading singleton axis documented in the return shape.
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    """
    TransformerLayerType is an Enum class to specify a type of TransformerLayer.

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


1541
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
1542
1543
1544
1545
1546
1547
1548
1549
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
1550
        The hidden size of each input sample.
1551
    mlp_hidden_size: int, default = 2048
1552
        Intermediate size to which input samples are projected.
1553
    num_attention_heads: int, default = 8
1554
        Number of attention heads in the transformer layer.
1555
    num_gqa_groups: int, default = `None`
zlsh80826's avatar
zlsh80826 committed
1556
1557
1558
1559
1560
1561
1562
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1563
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1564
        Indicate the type of layer normalization.
1565
    layernorm_epsilon: float, default = 1e-6
1566
        A value added to the denominator of layer normalization for numerical stability.
1567
    zero_centered_gamma: bool, default = False
1568
1569
1570
1571
1572
1573
1574
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1575
    hidden_dropout: float, default = 0.1
1576
        Dropout probability for the dropout op after FC2 layer.
1577
    hidden_dropout_dims: Sequence[int], default = ()
1578
        Dimensions that will share the same dropout mask for hidden
1579
    attention_dropout: float, default = 0.1
1580
        Dropout probability for the dropout op during multi-head attention.
1581
1582
1583
1584
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
1585
    dropout_rng_name: str, default = 'dropout'
1586
1587
        The key in given RNGs via flax.linen.Module.apply that for
        generating Dropout masks in the Multi-Head Attention.
1588
1589
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
1590
1591
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1592
1593
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
1594
1595
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1596
    mlp_activations: Sequence[str], default = ('relu', )
1597
        The sequence of activation functions to apply after the first linear transformation.
1598
1599
        Each activation has its own transformation layer.
    use_bias: bool, default = False
1600
1601
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
1602
    bias_init: Initializer, default = flax.linen.initializers.zeros
1603
1604
1605
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1606
    apply_residual_connection_post_layernorm: bool, default = False
1607
        If set to True, residual connections are taken from the output
1608
1609
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
1610
        If set to True, layer normalization is applied on the output side,
1611
1612
1613
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
1614
1615
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
1616
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
1617
        If set to TransformerLayerType.DECODER, an additional cross-attention block
1618
1619
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
1620
    self_attn_mask_type: str, default = 'causal'
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

1640
1641
1642
1643
1644
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
1645
    enable_relative_embedding: bool, default = True
1646
        Whether to enable relative embedding as shifting of attention logits.
1647
    relative_embedding: flax.linen.Module, default = None
1648
        The module for relative embedding execution, only used when
1649
1650
1651
1652
1653
1654
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
1655
1656
1657
1658
1659
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1660
    rotary_pos_emb_group_method: str, default = 'consecutive'
1661
1662
1663
1664
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
1665
1666
1667
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
1668
        'exclude_output_proj', 'exclude_mlp']
1669
1670
1671
1672
1673
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
1674
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
1675
1676
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1677
1678
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1679
1680
1681

    Optimization parameters
    -----------------------
1682
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1683
        The data type used to allocate the initial parameters.
1684
    drop_path: float, default = 0.0
1685
        When > 0.0, applies stochastic depth per sample in the main
1686
1687
        path of the residual block.
    fuse_qkv_params: bool, default = True
1688
        If set to True, `TransformerLayer` module exposes a single fused
1689
1690
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1691
    transpose_batch_sequence: bool, default = False
1692
        Indicate whether the input tensors were switched axis of batch
1693
1694
1695
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
1696
        Indicate whether to scale attention logits.
1697
1698
1699
        if set to True, :math:`\frac{Q*K}{\sqrt{head\_dim}}`,
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
1700
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
1701
1702
1703
1704
1705
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
1706
    num_gqa_groups: Optional[int] = None
1707
    layernorm_type: str = "layernorm"
1708
    layernorm_epsilon: float = 1e-6
1709
    zero_centered_gamma: bool = False
1710
1711
1712
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
1713
1714
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
1715
    dropout_rng_name: str = "dropout"
1716
1717
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
1718
    mlp_activations: Sequence[str] = ("relu",)
1719
1720
1721
1722
1723
1724
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
1725
    self_attn_mask_type: str = "causal"
1726
    self_attn_bias_type: Optional[str] = None
1727
1728
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
1729
1730
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1731
1732
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1733
1734
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1735
1736
1737
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
1738
    transpose_batch_sequence: bool = False
1739
    enable_sequence_parallel: bool = False
1740
1741
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1742
    window_size: Optional[Tuple[int, int]] = None
1743
1744
1745

    def __post_init__(self):
        """
        Fill in defaults for attributes left as `None`, then run the parent
        `__post_init__`.

        * `mha_kernel_init` defaults to
          `variance_scaling(1.0, 'fan_in', 'normal', dtype=self.dtype)`.
        * `mlp_kernel_init` defaults to
          `variance_scaling(1.0, 'fan_in', 'truncated_normal', dtype=self.dtype)`.
        * `num_gqa_groups` defaults to `num_attention_heads`.
        """
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            # No explicit group count: use one group per attention head.
            self.num_gqa_groups = self.num_attention_heads
        # Defaults are resolved before delegating to the parent class.
        super().__post_init__()

    @nn.compact
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: bool = None,
    ):
1768
1769
1770
1771
1772
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
1773
        inputs: jax.numpy.ndarray
1774
            Input tensor.
1775
        encoded: jax.numpy.ndarray, default = None
1776
1777
1778
1779
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
1780
1781
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
1782
        encoder_decoder_mask: jax.numpy.ndarray, default = None
1783
1784
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
1785
            :attr:`True` means mask out the corresponding values.
1786
        deterministic: bool, default = False
1787
            Disable dropout layers if set to True.
1788
        decode: bool, default = False
1789
1790
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
1791
        max_decode_length: bool, default = None
1792
1793
1794
1795
1796
1797
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
1798
        outputs: jax.numpy.ndarray
1799
            Output tensors.
1800
        """
1801

1802
        input_dtype = inputs.dtype
1803
1804
1805
        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
1806

1807
1808
1809
1810
        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )
1811

1812
1813
1814
        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."
1815
1816
1817
1818
1819
1820

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

1821
1822
1823
        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            axes = [None, None]

1824
1825
1826
            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )
1827
1828
1829
1830
1831

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

1832
1833
1834
        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
1835
1836
1837
1838
1839
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
1840
1841
1842
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
1843
1844
                    name="relpos_bias",
                )
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Make the name exactly the same as in T5X, since names affect the
        # RNGKey during init and apply. This may not be needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
1862
            mha_name = "attention"
1863
        else:
1864
            mha_name = "self_attention"
1865

1866
        inputs = with_sharding_constraint_by_logical_axes(
1867
1868
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1869

1870
        # [batch, length, emb_dim] -> [batch, length, emb_dim]
1871
1872
1873
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
1874
1875
            dtype=self.dtype,
            head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1876
            num_gqa_groups=self.num_gqa_groups,
1877
            transpose_batch_sequence=self.transpose_batch_sequence,
1878
            enable_sequence_parallel=self.enable_sequence_parallel,
1879
            attention_dropout=self.attention_dropout,
1880
1881
1882
1883
1884
1885
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
1886
            zero_centered_gamma=self.zero_centered_gamma,
1887
1888
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
1889
            attn_mask_type=self.self_attn_mask_type,
1890
            attn_bias_type=self.self_attn_bias_type,
1891
1892
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1893
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1894
1895
1896
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1897
            fuse_qkv_params=self.fuse_qkv_params,
1898
1899
1900
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
1901
            name=mha_name,
1902
            window_size=self.window_size,
1903
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
1904
1905
1906
1907
1908

        def hidden_dropout(x, deterministic):
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1909
                assert -x_shape_len <= dims < x_shape_len
1910

1911
1912
1913
1914
1915
            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1916

1917
        x = with_sharding_constraint_by_logical_axes(
1918
1919
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1920
        residual = with_sharding_constraint_by_logical_axes(
1921
1922
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1923

1924
1925
1926
        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
1927
1928
1929
1930
1931
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1932
1933
1934
1935
1936

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

1937
1938
1939
1940
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
1941
1942
1943
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."
1944

1945
            x = with_sharding_constraint_by_logical_axes(
1946
1947
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1948

1949
1950
1951
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
1952
1953
                dtype=self.dtype,
                head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1954
                num_gqa_groups=self.num_gqa_groups,
1955
                transpose_batch_sequence=self.transpose_batch_sequence,
1956
                enable_sequence_parallel=self.enable_sequence_parallel,
1957
                attention_dropout=self.attention_dropout,
1958
1959
1960
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
1961
                zero_centered_gamma=self.zero_centered_gamma,
1962
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
1963
1964
1965
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
1966
1967
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1968
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1969
1970
1971
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1972
1973
1974
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
1975
                fuse_qkv_params=self.fuse_qkv_params,
1976
1977
1978
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1979
                name="encoder_decoder_attention",
1980
                window_size=self.window_size,
1981
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)
1982
1983

            y = with_sharding_constraint_by_logical_axes(
1984
1985
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1986
            residual = with_sharding_constraint_by_logical_axes(
1987
1988
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1989

1990
            y = hidden_dropout(y, deterministic)
1991
1992
1993
1994
1995

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

1996
1997
            mlp_input = y + residual

1998
        mlp_input = with_sharding_constraint_by_logical_axes(
1999
2000
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2001

2002
2003
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

2004
2005
2006
2007
        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
2008
            zero_centered_gamma=self.zero_centered_gamma,
2009
2010
2011
2012
2013
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
2014
2015
2016
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
2017
            dtype=self.dtype,
2018
2019
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
2020
            kernel_init=self.mlp_kernel_init,
2021
2022
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
2023
2024
            use_bias=self.use_bias,
            bias_init=self.bias_init,
2025
2026
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
2027
2028
2029
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
2030
2031
2032
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
2033
            name="mlp",
2034
2035
2036
2037
2038
2039
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

2040
        z = with_sharding_constraint_by_logical_axes(
2041
2042
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2043
        residual = with_sharding_constraint_by_logical_axes(
2044
2045
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2046

2047
2048
2049
        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
2050
2051
2052
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
2053
2054
2055
        z = z + residual

        if self.output_layernorm:
2056
            z = with_sharding_constraint_by_logical_axes(
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                transpose_batch_sequence=self.transpose_batch_sequence,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
2069
        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
2070
        return z