# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
18
from flax.linen.attention import combine_masks
19
20
21
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
27
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
28
from ..attention import fused_attn
29
from ..attention import CPStrategy
30
from ..softmax import SoftmaxType
31
32
33
34
35
36
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
37
38
39
40
41

# Type aliases shared by the modules in this file.
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Accepted forms of a matmul precision setting: nothing, a name, a lax.Precision,
# or a pair of either.
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
# Callable taking (rng key, shape, dtype) and returning an array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Sequence of (logical axis name, mesh axis name or None) pairs.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    """Return the broadcast dims for drop_path.

    The result is a list holding every axis index of ``shape`` except the one
    at position ``batch_dim``.
    """
    broadcast_dims = [*range(len(shape))]
    del broadcast_dims[batch_dim]
    return broadcast_dims


# TODO(Phuong): move this function to sharding.py
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules: all given rules, in order, followed by every
        TE rule whose axis name does not appear in the given rules.

    Raises
    ------
    AssertionError
        If a rule is not a 2-element (str, str-or-None) pair, or if an axis name present in
        both the given rules and TE's rules does not map to exactly the same single mesh axis.
    """
    # Group the user-provided mesh axes by logical axis name so that duplicates and
    # conflicts with TE's own rules can be detected below.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item[0], item[1]
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for item in get_sharding_map_logic_axis_to_mesh_axis().items():
        key, val = item[0], item[1]
        if key in rules_map:
            # The axis is already covered by the caller; it must agree with TE's rule.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


114
115
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention built from ``jnp.einsum`` and the TE ``Softmax`` module.

    Supports both multi-head attention (query and key/value have the same number of heads)
    and grouped-query attention (query heads are a multiple of key/value heads).
    """

    # Dropout probability applied to the attention weights after the softmax.
    attention_dropout: float = 0.0
    # Selects the softmax variant and whether the ``mask`` argument is used.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    # How ``bias`` is combined with the logits (pre-scale, post-scale or not at all).
    attn_bias_type: Optional[AttnBiasType] = None
    # Not read inside this module.
    dtype: DType = jnp.float32
    # When True, query and key are cast to float32 before the logits are computed.
    float32_logits: bool = False
    # Logit scale; ``None`` means 1 / sqrt(head_dim of query).
    scale_factor: Optional[float] = None
    # True: inputs are (seqlen, batch, heads, head_dim); False: (batch, seqlen, heads, head_dim).
    transpose_batch_sequence: bool = True
    # Sliding window size forwarded to ``make_swa_mask``.
    window_size: Optional[Tuple[int, int]] = None

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """
        Compute softmax(scale * Q K^T [+ bias]) V.

        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Rank-4 tensors laid out as (seqlen, batch, heads, head_dim) when
            :attr:`transpose_batch_sequence` is True, otherwise (batch, seqlen, heads, head_dim).
            key and value must share sequence length and number of heads; query and key must
            share head_dim; the number of query heads must be a multiple of the key heads.
        mask: jax.numpy.ndarray, default = None
            Mask in which 1 marks positions to mask out. Its first axis is read as batch and
            its last two axes as (max_seqlen_q, max_seqlen_kv). Ignored for 'no_mask' and for
            'causal' without a sliding window.
        bias: jax.numpy.ndarray, default = None
            Bias added to the logits according to :attr:`attn_bias_type`.
        dropout_rng: PRNGKey, default = None
            RNG key used to sample the dropout mask.
        deterministic: bool, default = False
            Disable dropout if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Attention output in the same batch/sequence layout as ``query``.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        # Remember the incoming dtype: the softmax output is cast back to it below.
        input_dtype = query.dtype

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # NOTE(review): presumably removed so later code must use the resolved local value.
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (kv_heads, group_size).
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Merge (h, g) into a single head axis for the softmax; restored afterwards.
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # (b, h, q, k): Last two axes are always replicated
        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                # The bias has been consumed; do not hand it to Softmax again.
                bias = None

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            # Keep positions already masked; elsewhere take the sliding-window mask.
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
            attn_weights, mask, bias
        ).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Scale the kept entries by 1 / keep_prob.
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
266
267


268
269
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention that dispatches to ``fused_attn``.

    The module normalizes the inputs to batch-first order, packs them into the tuple
    expected for the configured :attr:`qkv_layout`, and forwards everything to ``fused_attn``.
    """

    # Dropout probability forwarded to ``fused_attn``.
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    # Not read inside this module.
    dtype: DType = jnp.float32
    # Decides which of query/key/value are read and how they are packed.
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    # Logit scale; ``None`` means 1 / sqrt(last dimension of query).
    scale_factor: Optional[float] = None
    # True: inputs/outputs have axes 0 and 1 swapped relative to what ``fused_attn`` receives.
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """
        Run fused attention for the configured :attr:`qkv_layout`.

        Parameters
        ----------
        query: jax.numpy.ndarray
            Query tensor, or the qkvpacked tensor ([..., 3, h, d]) for qkvpacked layouts.
        key: jax.numpy.ndarray
            Key tensor, or the kvpacked tensor ([..., 2, h, d]) for kvpacked layouts.
            Ignored for qkvpacked layouts.
        value: jax.numpy.ndarray
            Value tensor. Only read for separate layouts.
        sequence_descriptor: Optional[SequenceDescriptor], default = None
            Forwarded to ``fused_attn`` unchanged.
        bias: jax.numpy.ndarray, default = None
            Forwarded to ``fused_attn`` unchanged.
        dropout_rng: PRNGKey, default = None
            When given, split into one key per device and passed as the dropout seed.
        deterministic: bool, default = False
            ``fused_attn`` is called with ``is_training=not deterministic``.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Result of ``fused_attn``, with axes 0 and 1 swapped back when
            :attr:`transpose_batch_sequence` is True.

        Raises
        ------
        ValueError
            If :attr:`qkv_layout` is neither qkvpacked, kvpacked nor separate.
        """

        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # NOTE(review): presumably removed so later code must use the resolved local value.
        del self.scale_factor

        if self.qkv_layout.is_qkvpacked():
            # qkvpacked format, treat
            #   query: qkvpacked tensor, shape = [..., 3, h, d]
            #   key: ignore
            #   value: ignore
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            # kvpacked format, treat
            #   query: query tensor, shape = [..., h, d]
            #   key: kvpacked tensor, shape = [..., 2, h, d]
            #   value: ignore
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        # All layouts share the same call; only the packed qkv tuple differs.
        x = fused_attn(
            qkv,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
            context_parallel_strategy=self.context_parallel_strategy,
            context_checkpoint_name=self.context_checkpoint_name,
        )

        if self.transpose_batch_sequence:
            # Swap axes 0 and 1 back to match the caller's input order.
            x = x.transpose([1, 0, 2, 3])

        assert x.dtype == query.dtype
        return x


394
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
          attention kernel is not available on the system, a warning will be issued, and the module
          will automatically fall back to the unfused backend.

    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

        .. note:: THD format only supports 'padding' or 'causal_padding' mask type.

       attn_mask_type       mask/sequence_descriptor       SWA          softmax type
       --------------------------------------------------------------------------------------------
       no_mask              None                           None         SCALED
       causal               None                           None         SCALED_UPPER_TRIANG_MASKED
       causal               None                           Yes          SCALED_MASKED
       padding              Required                       Yes/No       SCALED_MASKED
       padding_causal       Required                       Yes/No       SCALED_MASKED

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
    context_parallel_causal_load_balanced: bool, default = False
        Indicates the sequences are ordered for causal mask load balancing when running context
        parallelism.
    context_parallel_axis: str, default = ""
        The name of the context parallel axis.
    context_parallel_strategy: str, default = "DEFAULT"
        The strategy of context parallel. Available options (case insensitive):
        {'DEFAULT', 'ALL_GATHER', 'ALLGATHER', 'RING'}.
    context_checkpoint_name: str, default = "context"
        The name of the context checkpoint in the forward pass of fused attention.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: str = "DEFAULT"
    context_checkpoint_name: str = "context"

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        sequence_descriptor: Optional[Union[SequenceDescriptor, jax.numpy.ndarray]], default = None
            Sequence/mask information forwarded to the selected backend. The unfused backend
            requires it to be :attr:`None` or an array and uses it as the attention mask.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: jax.numpy.ndarray, default = None
            Deprecated, please use :attr:`sequence_descriptor` instead; the two cannot be
            provided at the same time.
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.

        Raises
        ------
        ValueError
            If both :attr:`mask` and :attr:`sequence_descriptor` are given, or if
            :attr:`context_parallel_strategy` is not a supported option.
        """
        input_dtype = query.dtype

        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        # NOTE(review): presumably removed so the code below must use the enum locals.
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # Use fused attn (if kernel check below passes) by default
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        # NOTE(review): only BS3HD skips reading `key` here; confirm whether the other
        # qkvpacked layout (t3hd) should take the same shortcut.
        if qkv_layout == QKVLayout.BS3HD:
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]
        if qkv_layout.is_separate():
            head_dim_qk = query.shape[-1]
            head_dim_v = value.shape[-1]
        else:
            head_dim_qk = self.head_dim
            head_dim_v = self.head_dim

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            head_dim_qk,
            head_dim_v,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
            )

        # An RNG is only drawn when dropout will actually be applied.
        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(head_dim_qk)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # case-insensitive mapping for context parallel strategy
        cp_strategy_map = {
            "DEFAULT": CPStrategy.DEFAULT,
            "ALL_GATHER": CPStrategy.ALL_GATHER,
            "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
            "RING": CPStrategy.RING,
        }

        strategy_key = self.context_parallel_strategy.upper()
        if strategy_key in cp_strategy_map:
            context_parallel_strategy = cp_strategy_map[strategy_key]
        else:
            valid_strategies = list(cp_strategy_map.keys())
            raise ValueError(
                f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
                f"Valid options are: {valid_strategies} (case insensitive)"
            )

        if not use_fused_attn:
            # unfused attention only supports splitted query, key, value
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            # The unfused backend consumes the descriptor as a plain mask array.
            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x
740
741


742
743
744
745
746
747
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding
    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Every pair of coordinates along the last axis is rotated by the angle
    ``position / time_scale``, where ``position`` is the index along the sequence axis
    (starting at 0) and the i-th pair uses
    ``time_scale = min_window * (max_window / min_window) ** (2 * i / hidden)``.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Tensor to rotate. The size of its last axis must be even.
    windows: Tuple[int, int]
        The (min, max) time-scales of the embedding.
    transpose_batch_sequence: bool
        Whether the sequence axis comes first (True) or the batch axis comes first (False).
    group_method: str, default = 'consecutive'
        How coordinates are paired. 'consecutive' pairs index i with i + 1, 'alternate'
        pairs index i with i + hidden / 2. Case, surrounding blanks, '-' and '_' are ignored.

    Returns
    -------
    outputs: jax.numpy.ndarray
        The rotated tensor, with the same shape and dtype as :attr:`x`.

    Raises
    ------
    AssertionError
        If :attr:`group_method` is neither 'consecutive' nor 'alternate'.
    ValueError
        If the size of the last axis of :attr:`x` is odd.
    """

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid relative positional embedding group method. "
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."
        )

        return canonicalized_gm

    # Validate the arguments first so that a bad call fails before any array work is built.
    group_method = canonicalize_group_method(group_method)

    hidden_dim = x.shape[-1]
    if hidden_dim % 2 != 0:
        # Coordinates are rotated in pairs. An odd size would otherwise surface as an opaque
        # shape error, or (for a size of 1) silently broadcast to an empty last axis.
        raise ValueError(
            "Rotary positional embedding requires an even hidden dimension, "
            f"but got {hidden_dim}."
        )
    half_hidden_dim = hidden_dim // 2
    min_window = windows[0]
    max_window = windows[1]

    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Prepend singleton axes so that time_scales broadcasts against x along the last axis.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # positions has size seqlen on the sequence axis and size 1 on every other axis of x.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        # Pair index i with index i + hidden / 2.
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
        return output

    def consecutive_impl():
        # Pair index 2i with index 2i + 1; both members of a pair share one time-scale.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        # Each coordinate picks its partner: even indices take the next element,
        # odd indices take the previous one.
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 on even indices and +1 on odd indices.
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()
819
820


821
class LoRAScope:  # pylint: disable=too-few-public-methods
    """
    LoRA Scope

    Records which projections low rank adaptation is applied to.

    Parameters
    ----------
    qkv_proj: bool, default = False
        Whether LoRA is applied to the query/key/value projections.
    output_proj: bool, default = False
        Whether LoRA is applied to the output projection.
    mlp: bool, default = False
        Whether LoRA is applied to the MLP.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Defer to the other operand instead of raising AttributeError when it is not a
        # LoRAScope (e.g. a comparison against None).
        if not isinstance(other, LoRAScope):
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == (
            other.qkv_proj,
            other.output_proj,
            other.mlp,
        )

    # Instances are mutable and compare by value, so they stay unhashable. Python already
    # drops __hash__ when __eq__ is overridden; this only makes that explicit.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj!r}, "
            f"output_proj={self.output_proj!r}, mlp={self.mlp!r})"
        )
835
836
837
838


def _canonicalize_lora_scope(scope):
    """
    Translate a low rank adaptation scope name into a :class:`LoRAScope`.

    Parameters
    ----------
    scope: Optional[str]
        One of 'none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj' or 'exclude_mlp' (case-insensitive). `None` is treated
        as 'none'.

    Returns
    -------
    lora_scope: LoRAScope
        The scope with the flag of every selected projection set to True.

    Raises
    ------
    AssertionError
        If :attr:`scope` is not one of the names listed above.
    """

    SCOPE_NONE = "none"
    SCOPE_ALL = "all"
    SCOPE_QKV_PROJ = "qkv_proj"
    SCOPE_OUTPUT_PROJ = "output_proj"
    SCOPE_MLP = "mlp"
    SCOPE_EX_QKV_PROJ = "exclude_qkv_proj"
    SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj"
    SCOPE_EX_MLP = "exclude_mlp"

    valid_scopes = [
        SCOPE_NONE,
        SCOPE_ALL,
        SCOPE_QKV_PROJ,
        SCOPE_OUTPUT_PROJ,
        SCOPE_MLP,
        SCOPE_EX_QKV_PROJ,
        SCOPE_EX_OUTPUT_PROJ,
        SCOPE_EX_MLP,
    ]

    scope = SCOPE_NONE if scope is None else scope

    scope = scope.lower()

    # Name the offending value and the accepted ones so a typo is easy to diagnose.
    assert (
        scope in valid_scopes
    ), f"Invalid low rank adaptation scope. Expect to be in {valid_scopes}, but got {scope!r}."

    lora_scope = LoRAScope()

    # An "exclude_X" scope enables every projection except X.
    if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
        lora_scope.qkv_proj = True

    if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
        lora_scope.output_proj = True

    if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
        lora_scope.mlp = True

    return lora_scope


877
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
878
879
880
881
882
883
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
884
    head_dim: int
885
        The hidden dimension of each attention head.
886
887
888
889
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
zlsh80826's avatar
zlsh80826 committed
890
891
892
893
894
895
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
896
897
898
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

918
919
920
921
922
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
923
    dropout_rng_name: str, default = 'dropout'
924
925
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
926
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
927
        Indicate the type of layer normalization.
928
    layernorm_epsilon: float, default = 1e-6
929
        A value added to the denominator of layer normalization for numerical stability.
930
    zero_centered_gamma: bool, default = False
931
932
933
934
935
936
937
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
938
939
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
940
        Used for initializing the QKV and output projection weights.
941
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
942
    use_bias: bool, default = False
943
        Indicate whether or not to enable bias shifting for QKV and output projections.
944
        If set to False, the layer will not learn additive biases.
945
    bias_init: Initializer, default = flax.linen.initializers.zeros
946
947
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
948
949
950
951
952
953
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
954
955
956
957
958
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
959
960
961
962
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
963
964
965
966
967
968
969
970
971
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` enables LoRA for a projection of this module.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
972
973
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
974
975
976
977
978
979
980
981
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`.
982
983
984

    Optimization parameters
    -----------------------
985
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
986
        The data type used to allocate the initial parameters.
987
    fuse_qkv_params: bool, default = True
988
        If set to True, this module exposes a single fused
989
990
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
991
    transpose_batch_sequence: bool, default = True
992
        Indicate whether the input tensors were switched axis of batch
993
994
995
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
996
        Indicate whether to scale attention logits.
997
        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}}*K`,
998
        else :math:`Q*K`
999
1000
1001
1002
1003
1004
1005
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
1006
1007
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1008
1009
1010
    """

    head_dim: int
1011
1012
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
1013
1014
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
1015
    input_layernorm: bool = True
1016
1017
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
1018
    return_layernorm_output: bool = False
1019
    zero_centered_gamma: bool = False
1020
1021
1022
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
1023
    attn_mask_type: str = "causal"
1024
    attn_bias_type: Optional[str] = None
1025
1026
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1027
1028
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1029
1030
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1031
    dtype: DType = jnp.float32
1032
    fuse_qkv_params: bool = True
1033
    transpose_batch_sequence: bool = True
1034
    enable_sequence_parallel: bool = False
1035
1036
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1037
    float32_logits: bool = False
1038
    window_size: Optional[Tuple[int, int]] = None
1039
1040
1041
1042
1043
1044
1045

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
1046
1047

    def __post_init__(self):
        """
        Resolve deprecated parameters and fill in defaults that depend on other fields.

        Each deprecated alias that is set emits a :class:`DeprecationWarning` and is
        forwarded to the parameter that replaces it. :attr:`kernel_init` and
        :attr:`num_gqa_groups` are derived from other fields when left as `None`.

        Raises
        ------
        AssertionError
            If the removed parameter :attr:`output_layernorm` is set.
        """
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed soon. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed soon. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            # Honor the deprecated alias like the ones above instead of silently ignoring it.
            self.return_layernorm_output = self.apply_residual_connection_post_layernorm
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed soon, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            # Honor the deprecated alias like the ones above instead of silently ignoring it.
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed soon. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed soon. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        # The default initializer depends on dtype, so it cannot be a plain field default.
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        # Without an explicit group count every attention head gets its own group.
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1099
1100
1101
1102
1103
1104
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1105
        inputs_q: jax.numpy.ndarray
1106
            Input tensor for query projection.
1107
        inputs_kv: jax.numpy.ndarray
1108
            Input tensor for key/value projection.
1109
1110
1111
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
1112
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
1113
1114
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1115
        *
1116
        decode: bool, default = False
1117
            Indicate whether to prepare and use an autoregressive cache.
1118
        deterministic: bool, default = False
1119
1120
1121
1122
            Disable dropout layers if set to True.

        Returns
        -------
1123
        outputs: jax.numpy.ndarray
1124
1125
            Output tensors.
        """
1126

1127
1128
1129
1130
1131
        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

1132
        def query_init(*args):
1133
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

1176
1177
1178
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1179

1180
1181
1182
1183
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1184
1185
1186
1187
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1188
1189
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1190
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1191
            if is_qkvpack:
1192
                qkv_proj, ln_out = LayerNormDenseGeneral(
1193
                    enable_layernorm=self.input_layernorm,
1194
                    layernorm_type=self.layernorm_type,
1195
                    zero_centered_gamma=self.zero_centered_gamma,
1196
1197
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1198
1199
                    features=(3, self.num_attention_heads * self.head_dim),
                    return_layernorm_output=self.return_layernorm_output,
1200
1201
1202
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1203
1204
1205
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1206
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1207
1208
1209
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1210
1211
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1212
                    transpose_batch_sequence=self.transpose_batch_sequence,
1213
1214
1215
1216
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1217
                qkv_layout = QKVLayout.BS3HD
1218
1219
            else:
                query, ln_out = LayerNormDenseGeneral(
1220
                    enable_layernorm=self.input_layernorm,
1221
                    layernorm_type=self.layernorm_type,
1222
                    zero_centered_gamma=self.zero_centered_gamma,
1223
1224
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1225
1226
                    features=self.num_attention_heads * self.head_dim,
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1227
1228
1229
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1230
1231
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1232
                    bias_axes=(W_TP_AXES,),
1233
1234
1235
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1236
1237
                    dtype=self.dtype,
                    kernel_init=query_init,
1238
1239
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1240
                    transpose_batch_sequence=self.transpose_batch_sequence,
1241
1242
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1243
1244
1245
1246
1247

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1259
                    transpose_batch_sequence=self.transpose_batch_sequence,
1260
1261
1262
1263
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1264
                qkv_layout = QKVLayout.BSHD_BS2HD
1265
1266
1267
1268
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1269
                features=self.num_gqa_groups * self.head_dim,
1270
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1271
1272
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1273
                bias_axes=(W_TP_AXES,),
1274
1275
1276
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1277
1278
                dtype=self.dtype,
            )
1279
            query, ln_out = LayerNormDenseGeneral(
1280
                enable_layernorm=self.input_layernorm,
1281
                layernorm_type=self.layernorm_type,
1282
                zero_centered_gamma=self.zero_centered_gamma,
1283
1284
                epsilon=self.layernorm_epsilon,
                axis=-1,
1285
                features=self.num_attention_heads * self.head_dim,
1286
                return_layernorm_output=True,
1287
1288
1289
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1290
1291
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1292
                bias_axes=(W_TP_AXES,),
1293
1294
1295
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1296
1297
                dtype=self.dtype,
                kernel_init=query_init,
1298
1299
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1300
                transpose_batch_sequence=self.transpose_batch_sequence,
1301
1302
                name="query",
            )(inputs_q)
1303

1304
            if is_self_attn:
1305
1306
1307
                assert ln_out is not None
                inputs_kv = ln_out

1308
            query = query.astype(input_dtype)
1309
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1310
            key = key.astype(input_dtype)
1311
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
1312
            value = value.astype(input_dtype)
1313
1314
1315
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1316
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1317

1318
        if self.enable_rotary_pos_emb:
1319
1320
1321
1322
1323
1324
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1325

1326
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1327
1328
1329
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1342
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1343

1344
1345
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1346
1347
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1348
1349

        if decode:
1350
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1351
1352
1353
1354
1355
1356
1357
1358
1359
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1360
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1361
                if self.transpose_batch_sequence:
1362
1363
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1364
1365
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1366
1367
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1368
                    one_hot_indices_shape = (1, length, 1, 1)
1369
1370
1371
1372

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1373
1374
1375
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1376

1377
                cur_index = cache_index.value.astype(jnp.int32)
1378
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1379
1380
1381
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1382
1383
1384
1385
1386
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1387
1388
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1389
1390

                if bias is not None:
1391
1392
1393
1394
1395
1396
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1397

1398
1399
1400
1401
1402
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1403
1404
1405
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1417
        else:
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1428
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1442
            window_size=self.window_size,
1443
        )(*dpa_args, mask, bias, deterministic=deterministic)
1444
1445
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1446
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1447
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1448

1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1464

1465
1466
1467
        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
1468
        return out, ln_out
1469
1470


1471
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings added to the attention logits.

    Every (query position, key position) pair is mapped to a bucket index that
    depends only on their relative distance; a learned table of shape
    `(num_attention_heads, num_buckets)` is then looked up per bucket to build
    the attention bias.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings. When True, half of the buckets are reserved for keys
            located after the query; when False, such keys are treated as
            being at distance 0.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Bucket indices only depend on static sequence lengths, so they are
        # computed with NumPy on the host rather than with traced jnp ops.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        # negative_rp[i, j] = i - j: positive when key j precedes query i.
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in two halves; keys after the query
            # (negative_rp < 0) are offset into the upper half.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Keys after the query collapse onto distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # The first half of the (remaining) buckets holds one exact distance each ...
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        # ... the second half is spaced logarithmically up to max_distance.
        # eps keeps log() finite for distance 0; those entries are discarded by
        # the np.where below because they are always "small".
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        # Distances beyond max_distance all land in the last bucket.
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # Look the table up through a one-hot matmul instead of a gather:
        # one-hot has shape (num_buckets, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        # (heads, buckets) x (buckets, q, k) -> (heads, q, k)
        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        # Add a leading broadcast dimension: (1, heads, q, k).
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


1583
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
1584
1585
1586
1587
1588
1589
1590
1591
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
1592
        The hidden size of each input sample.
1593
    mlp_hidden_size: int, default = 2048
1594
        Intermediate size to which input samples are projected.
1595
    num_attention_heads: int, default = 8
1596
        Number of attention heads in the transformer layer.
1597
    num_gqa_groups: int, default = `None`
zlsh80826's avatar
zlsh80826 committed
1598
1599
1600
1601
1602
1603
1604
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1605
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1606
        Indicate the type of layer normalization.
1607
    layernorm_epsilon: float, default = 1e-6
1608
        A value added to the denominator of layer normalization for numerical stability.
1609
    zero_centered_gamma: bool, default = False
1610
1611
1612
1613
1614
1615
1616
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1617
    hidden_dropout: float, default = 0.1
1618
        Dropout probability for the dropout op after FC2 layer.
1619
    hidden_dropout_dims: Sequence[int], default = ()
1620
        Dimensions that will share the same dropout mask for hidden
1621
    attention_dropout: float, default = 0.1
1622
        Dropout probability for the dropout op during multi-head attention.
1623
1624
1625
1626
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
1627
    dropout_rng_name: str, default = 'dropout'
1628
1629
        The key in given RNGs via flax.linen.Module.apply that for
        generating Dropout masks in the Multi-Head Attention.
1630
1631
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
1632
1633
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1634
1635
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
1636
1637
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1638
    mlp_activations: Sequence[str], default = ('relu', )
1639
        The sequence of activation functions to apply after the first linear transformation.
1640
        Each activation has its own transformation layer.
1641
1642
1643
    mlp_activation_params: dict, default = None
        This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`.
        At the moment ClampedSwiglu is the only activation that requires parameters.
1644
    use_bias: bool, default = False
1645
1646
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
1647
    bias_init: Initializer, default = flax.linen.initializers.zeros
1648
1649
1650
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1651
    apply_residual_connection_post_layernorm: bool, default = False
1652
        If set to True, residual connections are taken from the output
1653
1654
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
1655
        If set to True, layer normalization is applied on the output side,
1656
1657
1658
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
1659
1660
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
1661
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
1662
        If set to TransformerLayerType.DECODER, an additional cross-attention block
1663
1664
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
1665
    self_attn_mask_type: str, default = 'causal'
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

1685
1686
1687
1688
1689
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
1690
    enable_relative_embedding: bool, default = True
1691
        Whether to enable relative embedding as shifting of attention logits.
1692
    relative_embedding: flax.linen.Module, default = None
1693
        The module for relative embedding execution, only used when
1694
1695
1696
1697
1698
1699
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
1700
1701
1702
1703
1704
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1705
    rotary_pos_emb_group_method: str, default = 'consecutive'
1706
1707
1708
1709
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
1710
1711
1712
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
1713
        'exclude_output_proj', 'exclude_mlp']
1714
1715
1716
1717
1718
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
1719
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
1720
1721
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1722
1723
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1724
1725
1726

    Optimization parameters
    -----------------------
1727
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1728
        The data type used to allocate the initial parameters.
1729
    drop_path: float, default = 0.0
1730
        When > 0.0, applies stochastic depth per sample in the main
1731
1732
        path of the residual block.
    fuse_qkv_params: bool, default = True
1733
        If set to True, `TransformerLayer` module exposes a single fused
1734
1735
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1736
    transpose_batch_sequence: bool, default = False
1737
        Indicate whether the input tensors were switched axis of batch
1738
1739
1740
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
1741
        Indicate whether to scale attention logits.
1742
1743
1744
        If set to True, the logits are computed as :math:`\frac{Q*K}{\sqrt{head\_dim}}`,
        else as :math:`Q*K`.
    scaled_query_init: bool, default = `True`
1745
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
1746
1747
1748
1749
1750
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
1751
    num_gqa_groups: Optional[int] = None
1752
    layernorm_type: str = "layernorm"
1753
    layernorm_epsilon: float = 1e-6
1754
    zero_centered_gamma: bool = False
1755
1756
1757
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
1758
1759
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
1760
    dropout_rng_name: str = "dropout"
1761
1762
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
1763
    mlp_activations: Sequence[str] = ("relu",)
1764
    mlp_activation_params: dict = None
1765
1766
1767
1768
1769
1770
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
1771
    self_attn_mask_type: str = "causal"
1772
    self_attn_bias_type: Optional[str] = None
1773
1774
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
1775
1776
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1777
1778
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1779
1780
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1781
1782
1783
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
1784
    transpose_batch_sequence: bool = False
1785
    enable_sequence_parallel: bool = False
1786
1787
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1788
    window_size: Optional[Tuple[int, int]] = None
1789
1790
1791

    def __post_init__(self):
        """
        Resolve field defaults that depend on other fields.

        * :attr:`mha_kernel_init` falls back to
          ``variance_scaling(1.0, 'fan_in', 'normal')`` created with :attr:`dtype`.
        * :attr:`mlp_kernel_init` falls back to
          ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` created with :attr:`dtype`.
        * :attr:`num_gqa_groups` falls back to :attr:`num_attention_heads`.
        """
        # NOTE(review): these assignments presumably have to happen before
        # super().__post_init__() finalizes the module -- confirm against Flax.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: bool = None,
    ):
1814
1815
1816
1817
1818
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
1819
        inputs: jax.numpy.ndarray
1820
            Input tensor.
1821
        encoded: jax.numpy.ndarray, default = None
1822
1823
1824
1825
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
1826
1827
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
1828
        encoder_decoder_mask: jax.numpy.ndarray, default = None
1829
1830
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
1831
            :attr:`True` means mask out the corresponding values.
1832
        deterministic: bool, default = False
1833
            Disable dropout layers if set to True.
1834
        decode: bool, default = False
1835
1836
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
1837
        max_decode_length: int, default = None
1838
1839
1840
1841
1842
1843
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
1844
        outputs: jax.numpy.ndarray
1845
            Output tensors.
1846
        """
1847

1848
        input_dtype = inputs.dtype
1849
1850
1851
        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
1852

1853
1854
1855
1856
        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )
1857

1858
1859
1860
        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."
1861
1862
1863
1864
1865
1866

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

1867
1868
1869
        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            axes = [None, None]

1870
1871
1872
            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )
1873
1874
1875
1876
1877

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

1878
1879
1880
        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
1881
1882
1883
1884
1885
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
1886
1887
1888
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
1889
1890
                    name="relpos_bias",
                )
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Make name be the exactly same as T5X, since names would affect
        # RNGKey during init and apply. Maybe no need in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
1908
            mha_name = "attention"
1909
        else:
1910
            mha_name = "self_attention"
1911

1912
        inputs = with_sharding_constraint_by_logical_axes(
1913
1914
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1915

1916
        # [batch, length, emb_dim] -> [batch, length, emb_dim]
1917
1918
1919
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
1920
1921
            dtype=self.dtype,
            head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
1922
            num_gqa_groups=self.num_gqa_groups,
1923
            transpose_batch_sequence=self.transpose_batch_sequence,
1924
            enable_sequence_parallel=self.enable_sequence_parallel,
1925
            attention_dropout=self.attention_dropout,
1926
1927
1928
1929
1930
1931
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
1932
            zero_centered_gamma=self.zero_centered_gamma,
1933
1934
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
1935
            attn_mask_type=self.self_attn_mask_type,
1936
            attn_bias_type=self.self_attn_bias_type,
1937
1938
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
1939
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
1940
1941
1942
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1943
            fuse_qkv_params=self.fuse_qkv_params,
1944
1945
1946
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
1947
            name=mha_name,
1948
            window_size=self.window_size,
1949
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
1950
1951
1952
1953
1954

        def hidden_dropout(x, deterministic):
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1955
                assert -x_shape_len <= dims < x_shape_len
1956

1957
1958
1959
1960
1961
            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1962

1963
        x = with_sharding_constraint_by_logical_axes(
1964
1965
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1966
        residual = with_sharding_constraint_by_logical_axes(
1967
1968
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
1969

1970
1971
1972
        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
1973
1974
1975
1976
1977
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)
1978
1979
1980
1981
1982

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

1983
1984
1985
1986
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
1987
1988
1989
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."
1990

1991
            x = with_sharding_constraint_by_logical_axes(
1992
1993
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
1994

1995
1996
1997
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
1998
1999
                dtype=self.dtype,
                head_dim=head_dim,
zlsh80826's avatar
zlsh80826 committed
2000
                num_gqa_groups=self.num_gqa_groups,
2001
                transpose_batch_sequence=self.transpose_batch_sequence,
2002
                enable_sequence_parallel=self.enable_sequence_parallel,
2003
                attention_dropout=self.attention_dropout,
2004
2005
2006
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
2007
                zero_centered_gamma=self.zero_centered_gamma,
2008
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
2009
2010
2011
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
2012
2013
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
2014
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
2015
2016
2017
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
2018
2019
2020
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
2021
                fuse_qkv_params=self.fuse_qkv_params,
2022
2023
2024
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
2025
                name="encoder_decoder_attention",
2026
                window_size=self.window_size,
2027
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)
2028
2029

            y = with_sharding_constraint_by_logical_axes(
2030
2031
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
2032
            residual = with_sharding_constraint_by_logical_axes(
2033
2034
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
2035

2036
            y = hidden_dropout(y, deterministic)
2037
2038
2039
2040
2041

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

2042
2043
            mlp_input = y + residual

2044
        mlp_input = with_sharding_constraint_by_logical_axes(
2045
2046
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2047

2048
2049
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

2050
2051
2052
2053
        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
2054
            zero_centered_gamma=self.zero_centered_gamma,
2055
2056
2057
2058
            epsilon=self.layernorm_epsilon,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
2059
            activation_params=self.mlp_activation_params,
2060
2061
2062
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
2063
            dtype=self.dtype,
2064
2065
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
2066
            kernel_init=self.mlp_kernel_init,
2067
2068
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
2069
2070
            use_bias=self.use_bias,
            bias_init=self.bias_init,
2071
2072
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
2073
2074
2075
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
2076
2077
2078
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
2079
            transpose_batch_sequence=self.transpose_batch_sequence,
2080
            name="mlp",
2081
2082
2083
2084
2085
2086
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

2087
        z = with_sharding_constraint_by_logical_axes(
2088
2089
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2090
        residual = with_sharding_constraint_by_logical_axes(
2091
2092
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
2093

2094
2095
2096
        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
2097
2098
2099
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
2100
2101
2102
        z = z + residual

        if self.output_layernorm:
2103
            z = with_sharding_constraint_by_logical_axes(
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
2115
        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
2116
        return z