transformer.py 102 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
18
from flax.linen.attention import combine_masks
19
20
21
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
27
28
29
30
31
32
from ..attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    SequenceDescriptor,
)
33
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
34
from ..attention import fused_attn
35
from ..attention import CPStrategy
36
from ..softmax import SoftmaxFusionType
37
38
39
40
41
42
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
43
44
45
46
47

# Shorthand type aliases used by the annotations in this module.
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Matmul precision spec: nothing, a single value, or a pair of values.
PrecisionLike = Union[
    None,
    str,
    lax.Precision,
    Tuple[str, str],
    Tuple[lax.Precision, lax.Precision],
]
# Parameter initializer signature: (rng, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical-axis rules: a sequence of (logical_axis_name, mesh_axis_name or None).
LogicalRules = Sequence[Tuple[str, Optional[str]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    """Return the indices of every axis of ``shape`` except ``batch_dim``.

    These are the broadcast dimensions used by drop_path. ``batch_dim`` follows
    ordinary list indexing, so a negative value counts from the last axis and an
    out-of-range value raises ``IndexError``.
    """
    broadcast_dims = list(range(len(shape)))
    del broadcast_dims[batch_dim]
    return broadcast_dims


# TODO(Phuong): move this function to sharding.py
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to Figure 3 in `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, as a tuple: the given rules first, in their
        original order, followed by every TE rule whose logical axis the given rules do
        not mention.

    Raises
    ------
    AssertionError
        If a rule is not an ``(axis_name, mesh_axis_name)`` pair with a ``str`` axis name
        and a ``str``/``None`` mesh axis name, or if the given rules map a logical axis
        that TE also maps to anything other than exactly TE's mesh axis.
    """
    # Group the caller's mesh axes by logical axis name (a name may occur in several rules).
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for key, val in get_sharding_map_logic_axis_to_mesh_axis().items():
        if key in rules_map:
            # The caller already covers this axis: it must map it once, exactly as TE does.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append((key, val))
    return tuple(extended_rules)


120
121
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention built from plain JAX ops (einsum + the ``Softmax`` module).

    This is the non-cuDNN backend. It handles grouped-query attention (fewer key/value
    heads than query heads), sliding-window masking, pre-/post-scale bias, attention
    dropout and the softmax variants in :class:`AttnSoftmaxType`.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    float32_logits: bool = False
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Attend ``query`` over ``key``/``value``.

        Inputs are ``[b, s, h, d]`` (``[s, b, h, d]`` when ``transpose_batch_sequence``).
        The result has the same layout as ``query``.

        ``mask`` marks masked-out positions with nonzero values. It is ignored for
        ``NO_MASK``, and for ``CAUSAL_MASK`` without a sliding window.

        ``bias`` is handled according to ``attn_bias_type``:

        * ``PRE_SCALE_BIAS``: added to the logits before scaling.
        * ``POST_SCALE_BIAS``: added after scaling.

        ``dropout_rng`` must be given when ``deterministic`` is False and
        ``attention_dropout > 0``.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        input_dtype = query.dtype

        # query is [..., h, d]; h is the number of query heads.
        num_attention_heads = query.shape[-2]

        # Learnable softmax owns a per-head offset parameter of shape (1, h, 1, 1).
        # OFF_BY_ONE_SOFTMAX needs no parameter; the Softmax module handles it internally.
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        # Resolve the scale into a local. The module field is only read, never deleted:
        # deleting a dataclass field's instance attribute makes later reads fall back to
        # the class default (None), silently changing the scale on a re-call or clone.
        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (h_kv, group_size): each run of group_size
            # consecutive query heads shares one key/value head.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Fold the group axis into the head axis so the softmax sees (b, h, q, k).
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # (b, h, q, k): Last two axes are always replicated
        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
        )

        # With post_scale_bias the computation is Softmax(attn_weights * scale + bias),
        # so the scale cannot be fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # Otherwise the scale can be fused into the Softmax module.
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                bias = None

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_fusion_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxFusionType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxFusionType.SCALED_MASKED, apply_swa_mask(mask)
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxFusionType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_fusion_type, mask = convert_to_softmax_fusion_type(self.attn_mask_type, mask)

        attn_weights = Softmax(
            softmax_fusion_type=softmax_fusion_type,
            softmax_type=self.softmax_type,
            scale_factor=fused_scale_factor,
        )(attn_weights, mask, bias, softmax_offset=softmax_offset).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Inverted dropout: rescale kept weights by 1 / keep_prob.
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
290
291


292
293
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention dispatched to the cuDNN-backed ``fused_attn`` wrapper.

    ``qkv_layout`` selects between a packed qkv tensor, a query plus packed kv tensor,
    and separate query/key/value tensors.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Run fused attention.

        How the inputs are interpreted depends on ``qkv_layout``:

        * qkv-packed: ``query`` is the packed ``[..., 3, h, d]`` tensor; ``key`` and
          ``value`` are ignored.
        * kv-packed: ``query`` is ``[..., h, d]`` and ``key`` is the packed
          ``[..., 2, h, d]`` tensor; ``value`` is ignored.
        * separate: ``query``, ``key`` and ``value`` are each ``[..., h, d]``.

        With ``transpose_batch_sequence`` the two leading axes are swapped before the
        kernel call, and swapped back on the result.

        Raises ``ValueError`` for an unsupported ``qkv_layout``.
        """
        seed = None
        if dropout_rng is not None:
            # One dropout key per device.
            seed = jax.random.split(dropout_rng, num_of_devices())

        # Resolve the scale into a local. The module field is only read, never deleted:
        # deleting a dataclass field's instance attribute makes later reads fall back to
        # the class default (None), silently changing the scale on a re-call or clone.
        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor

        # query.shape[-2] is the head count for every layout ([..., h, d] or [..., 3, h, d]).
        num_attention_heads = query.shape[-2]
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        # Bring every input into batch-first order and collect the tensors fused_attn expects.
        if self.qkv_layout.is_qkvpacked():
            # query holds packed q/k/v; key and value are ignored.
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv_inputs = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            # key holds packed k/v; value is ignored.
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv_inputs = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv_inputs = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            qkv_inputs,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            softmax_type=self.softmax_type,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
            context_parallel_strategy=self.context_parallel_strategy,
            context_checkpoint_name=self.context_checkpoint_name,
            softmax_offset=softmax_offset,
        )

        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        assert x.dtype == query.dtype
        return x


435
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

451
452
453
454
        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
          attention kernel is not available on the system, a warning will be issued, and the module
          will automatically fall back to the unfused backend.
455

456
457
458
459
460
461
462
463
    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

464
465
466
467
468
469
    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
Paweł Gadziński's avatar
Paweł Gadziński committed
470
    num_gqa_groups: int, default = None
471
472
473
474
475
476
477
478
479
480
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
481
482
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
Paweł Gadziński's avatar
Paweł Gadziński committed
483
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
484
485
486

        Each described below:

Paweł Gadziński's avatar
Paweł Gadziński committed
487
        * ``no_mask``: No attention mask is applied. This means the attention will consider the
488
          full sequence without any restrictions.
Paweł Gadziński's avatar
Paweł Gadziński committed
489
490
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
491
          :attr:`__call__` method to specify the padding positions.
Paweł Gadziński's avatar
Paweł Gadziński committed
492
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
493
494
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
Paweł Gadziński's avatar
Paweł Gadziński committed
495
496
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
497

Paweł Gadziński's avatar
Paweł Gadziński committed
498
        |
499

Paweł Gadziński's avatar
Paweł Gadziński committed
500
        .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
501

Paweł Gadziński's avatar
Paweł Gadziński committed
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
        |

        .. note:: THD format only supports ``'padding'`` or ``'causal_padding'`` mask type.

        |

        .. table::
            :widths: auto

            ================== ============ ========== ==============================
            attn_mask_type     mask/sd      SWA        softmax type
            ================== ============ ========== ==============================
            no_mask            None         None       SCALED
            causal             None         None       SCALED_UPPER_TRIANG_MASKED
            causal             None         Yes        SCALED_MASKED
            padding            Required     Yes/No     SCALED_MASKED
            padding_causal     Required     Yes/No     SCALED_MASKED
            ================== ============ ========== ==============================

        where sd stands for sequence_descriptor.
522

523
    attn_bias_type: Optional[str], default = None
524
        Type of the attention bias passed in the attention.
525
526
527
528
529
530
531
532
533
534
535
536
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When left at the default (None), the type is decided automatically from the MHA's bias
        parameter: :attr:`post_scale_bias` if a bias is given, otherwise :attr:`no_bias`.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
537
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.
538
539
540
541
542
543

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separate tensors, each with shape = [b, s, h, d].
544
545
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.
546
547
548
549
550
551
552
553
554
555
556
557

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply to the query. When :attr:`None`, the scale factor is
        :math:`\frac{1}{\sqrt{head\_dim}}`. Models such as T5X, which do not scale the query,
        should set :attr:`scale_factor=1.`.
558
559
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
560
        Indicate whether the input tensors were switched axis of batch
561
        and sequence length dimension. If set to True, the input tensors
562
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
563
564
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
565
566
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
Paweł Gadziński's avatar
Paweł Gadziński committed
567
568
569
570
571
572
573
574
    context_parallel_causal_load_balanced: bool
        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
    context_parallel_axis: str
        The name of the context parallel axis.
    context_parallel_strategy: CPStrategy
        The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
    context_checkpoint_name: str
        The name of the context checkpoint in the forward pass of fused attention.
575
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Paweł Gadziński's avatar
Paweł Gadziński committed
576
        Softmax type as described in the paper
577
578
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
601
602
603
604
605

    Optimization parameters
    -----------------------
    dtype (deprecated): jax.numpy.dtype, default = None
        This dtype is deprecated and will be removed in a future release. DPA uses the dtype
        of the inputs instead; apart from the float32 offset created for
        ``softmax_type='learnable'``, this module has no parameters.
606
    """
607

608
609
610
    # Module configuration; the class docstring describes each field in detail.
    # Keep the declaration order: dataclass fields define the constructor's argument order.
    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None  # None -> equal to num_attention_heads
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"  # given by name, e.g. 'causal', 'padding'
    attn_bias_type: AttnBiasType = None  # None -> decided from the MHA's bias parameter
    dtype: Optional[DType] = None  # Deprecated
    dropout_rng_name: str = "dropout"  # RNG key passed to Module.apply for dropout
    float32_logits: bool = False  # unfused backend only
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None  # None -> 1 / sqrt(head_dim)
    # TODO(KshitijLakhani): make this a plain bool defaulting to False in v2.12.
    # None is replaced by False (with a warning) in __post_init__.
    transpose_batch_sequence: bool | None = None
    window_size: Optional[Tuple[int, int]] = None  # None -> no sliding window
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: str = "DEFAULT"  # CPStrategy name: DEFAULT, ALL_GATHER or RING
    context_checkpoint_name: str = "context"
    softmax_type: str = "vanilla"  # 'vanilla', 'off-by-one' or 'learnable'
627

628
629
630
631
632
633
634
635
636
637
638
    def __post_init__(self):
        """Resolve the legacy ``transpose_batch_sequence=None`` default.

        ``None`` means the caller did not choose a layout. Warn about the changed
        default and pin the field to ``False`` before running the base-class
        ``__post_init__``.
        """
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        relies_on_implicit_default = self.transpose_batch_sequence is None
        if relies_on_implicit_default:
            warnings.warn(
                "transpose_batch_sequence defaults to False in DotProductAttention starting"
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False
        super().__post_init__()

639
640
641
642
643
644
645
    def _assert_dtypes(self, query: Array, key: Array, value: Array, qkv_layout: QKVLayout):
        """Check that the key/value dtypes agree with the query dtype for ``qkv_layout``.

        * QKV-packed layouts: nothing is compared (everything lives in ``query``).
        * KV-packed layouts: only ``key`` is compared against ``query``.
        * Separate layouts: both ``key`` and ``value`` are compared against ``query``.

        Raises:
            AssertionError: if a compared dtype differs from ``query.dtype``.
            ValueError: if ``qkv_layout`` is none of the three layout families above.
        """
        if qkv_layout.is_qkvpacked():
            # Q, K and V share a single tensor, so there is nothing to compare.
            return

        if qkv_layout.is_kvpacked():
            assert (
                key.dtype == query.dtype
            ), f"Expected kv {key.dtype=} to match query {query.dtype=}."
            return

        if not qkv_layout.is_separate():
            raise ValueError(f"Unsupported {qkv_layout=}.")

        assert (
            key.dtype == query.dtype
        ), f"Expected key {key.dtype=} to match query {query.dtype=}."
        assert (
            value.dtype == query.dtype
        ), f"Expected value {value.dtype=} to match query {query.dtype=}."

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
            For QKV-packed layouts it is not used to infer the key/value sequence length.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        sequence_descriptor: SequenceDescriptor or jax.numpy.ndarray, default = None
            Masking / sequence information for the attention softmax input.
            The unfused backend only accepts ``None`` or an array here.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input. It is cast to the query dtype.
            Must be given exactly when the resolved :attr:`attn_bias_type` is not ``'no_bias'``.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: jax.numpy.ndarray, default = None
            Deprecated alias of ``sequence_descriptor``; cannot be combined with it.
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors, with the same dtype as ``query``.

        Raises
        ------
        ValueError
            If both ``mask`` and ``sequence_descriptor`` are given, if the deprecated
            ``dtype`` attribute differs from the input dtype, or if
            :attr:`context_parallel_strategy` is not recognised.
        """
        input_dtype = query.dtype

        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # Convert the string-valued attributes into the enums used internally.
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
        # Drop the string-valued fields from this instance so only the enum values computed
        # above are used for the rest of the call.
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None
            bias = bias.astype(input_dtype)

        self._assert_dtypes(query, key, value, qkv_layout)

        if self.dtype is not None:
            if self.dtype == input_dtype:
                warnings.warn(
                    "The dtype argument is deprecated and will be removed in a future release."
                    " DotProductAttention will use the dtype of the inputs instead as this module"
                    f" does not have any parameters. Module dtype specified {self.dtype=} matches"
                    " dtype of inputs so behavior is unchanged. Please remove the dtype argument"
                    " within the next few releases."
                )
            else:
                raise ValueError(
                    "The DotProductAttention module dtype is deprecated and will be removed in a"
                    " future release. DotProductAttention will use the dtype of the inputs instead"
                    " as this module does not have any parameters. Module dtype specified"
                    f" {self.dtype=} does not match dtype of inputs  {input_dtype=}."
                )

        # Use fused attn (if kernel check below passes) by default
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout.is_qkvpacked():
            # Q, K and V are all packed into `query`, so they share one sequence length and
            # `key` must not be consulted (it is not checked by _assert_dtypes either).
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        if qkv_layout.is_separate():
            head_dim_qk = query.shape[-1]
            head_dim_v = value.shape[-1]
        else:
            head_dim_qk = self.head_dim
            head_dim_v = self.head_dim

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            input_dtype,
            # self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient
            input_dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            head_dim_qk,
            head_dim_v,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            # Report every input of the availability check so the fallback can be diagnosed.
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{deterministic=}\n{input_dtype=}\n"
                f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n{softmax_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
                f"{self.window_size=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(head_dim_qk)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # case-insensitive mapping for context parallel strategy
        cp_strategy_map = {
            "DEFAULT": CPStrategy.DEFAULT,
            "ALL_GATHER": CPStrategy.ALL_GATHER,
            "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
            "RING": CPStrategy.RING,
        }

        strategy_key = self.context_parallel_strategy.upper()
        if strategy_key in cp_strategy_map:
            context_parallel_strategy = cp_strategy_map[strategy_key]
        else:
            valid_strategies = list(cp_strategy_map.keys())
            raise ValueError(
                f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
                f"Valid options are: {valid_strategies} (case insensitive)"
            )

        if not use_fused_attn:
            # The unfused backend only supports separate query, key and value tensors,
            # so unpack the packed dimension (axis -3) first.
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x


def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding

    x should be of shape
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: Array
        Tensor to rotate. Its last (hidden) dimension is split into coordinate pairs,
        so it is expected to be even.
    windows: Tuple[int, int]
        ``(min_window, max_window)``. The per-pair time-scales form a geometric
        progression starting at ``min_window`` and growing towards ``max_window``.
    transpose_batch_sequence: bool
        If True, axis 0 is the sequence axis and axis 1 the batch axis; otherwise the
        other way around.
    group_method: str, default = 'consecutive'
        How hidden coordinates are paired. ``'consecutive'`` pairs index ``i`` with
        ``i + 1`` (``i`` even); ``'alternate'`` pairs index ``i`` with ``i + d/2``, where
        ``d`` is the hidden dimension. Matching ignores case, surrounding whitespace,
        ``'-'`` and ``'_'``.

    Returns
    -------
    Array with the same shape and dtype as ``x``.
    """

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid rotary positional embedding group method. "
            f"Expected one of ['alternate', 'consecutive'], but got {gm}."
        )
        return canonicalized_gm

    # Fail fast on a bad group method before doing any array work.
    group_method = canonicalize_group_method(group_method)

    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2

    min_window = windows[0]
    max_window = windows[1]

    # One time-scale per coordinate pair, geometrically spaced between the windows.
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Add singleton axes for every leading dimension of x; only the hidden axis varies.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Position index along the sequence axis, broadcastable against x.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        return jnp.sin(sinusoidal_positions), jnp.cos(sinusoidal_positions)

    def alternate_impl():
        # Pair coordinate i of the first half with coordinate i of the second half.
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)

        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)

        return jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)

    def consecutive_impl():
        # Pair coordinate 2k with 2k + 1; both members of a pair share one time-scale.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        is_odd = jnp.arange(hidden_dim) % 2 == 1
        # Partner coordinate: x[i + 1] for even i, x[i - 1] for odd i.
        partner = jnp.where(is_odd, jnp.roll(x, 1, axis=-1), jnp.roll(x, -1, axis=-1))
        # The even member of a pair receives -sin, the odd member +sin.
        sign = jnp.where(is_odd, 1.0, -1.0)

        output = x * cos + partner * sin * sign
        return output.astype(x.dtype)

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()


class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    Flags recording which parts of the model low rank adaptation is scoped to.

    Attributes:
        qkv_proj: whether the QKV projection is in scope.
        output_proj: whether the attention output projection is in scope.
        mlp: whether the MLP is in scope.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Compare by the three flags. Objects lacking them are not comparable: return
        # NotImplemented so Python falls back to its default (False) instead of raising.
        try:
            other_flags = (other.qkv_proj, other.output_proj, other.mlp)
        except AttributeError:
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == other_flags

    # Mutable value-compared object: explicitly unhashable, as defining __eq__ alone
    # already implied.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj!r}, "
            f"output_proj={self.output_proj!r}, mlp={self.mlp!r})"
        )


def _canonicalize_lora_scope(scope):
    """Translate a LoRA scope name into a :class:`LoRAScope`.

    Args:
        scope: one of ``'none'``, ``'all'``, ``'qkv_proj'``, ``'output_proj'``, ``'mlp'``,
            ``'exclude_qkv_proj'``, ``'exclude_output_proj'`` or ``'exclude_mlp'``
            (case-insensitive). ``None`` is treated as ``'none'``.

    Returns:
        A LoRAScope whose ``qkv_proj`` / ``output_proj`` / ``mlp`` flags are True for the
        parts selected by ``scope``.

    Raises:
        AssertionError: if ``scope`` is not one of the names above.
    """
    # scope name -> (qkv_proj, output_proj, mlp)
    scope_flags = {
        "none": (False, False, False),
        "all": (True, True, True),
        "qkv_proj": (True, False, False),
        "output_proj": (False, True, False),
        "mlp": (False, False, True),
        "exclude_qkv_proj": (False, True, True),
        "exclude_output_proj": (True, False, True),
        "exclude_mlp": (True, True, False),
    }

    scope = "none" if scope is None else scope.lower()
    assert (
        scope in scope_flags
    ), f"Invalid LoRA scope {scope!r}; expected one of {list(scope_flags)} (case insensitive)."

    qkv_proj, output_proj, mlp = scope_flags[scope]
    return LoRAScope(qkv_proj=qkv_proj, output_proj=output_proj, mlp=mlp)


class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = None
        Number of GQA groups. When ``None``, it is set to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.

        Each is described below:

        * ``no_mask``: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
          :attr:`__call__` method to specify the padding positions.
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
        When left as ``None``, the type is decided by the MHA's bias parameter:
        ``'post_scale_bias'`` if a bias is given, otherwise ``'no_bias'``.
    dropout_rng_name: str, default = 'dropout'
        The name of the RNG stream, supplied via ``flax.linen.Module.apply``, that is used
        to generate dropout masks in the core attention.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to ``True``, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
            (1 + \gamma) + \beta

        This parameter is only applicable for ``'layernorm'``.
    kernel_init: Initializer, default =
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
        Used for initializing the QKV and output projection weights.
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    use_bias: bool, default = False
        Indicate whether or not to enable bias shifting for QKV and output projections.
        If set to ``False``, the layer will not learn additive biases.
    bias_init: Initializer, default = ``flax.linen.initializers.zeros``
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    input_layernorm: bool, default = True
        If set to ``False``, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to ``True``, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
    enable_rotary_pos_emb: bool, default = False
        Whether to apply rotary position embedding to the projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
        ``['consecutive', 'alternate']``. ``'alternate'`` pairs index :math:`i` with :math:`i + d/2`,
        where d is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ``['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']``
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not ``'none'``.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    fuse_qkv_params: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (treated as False, with a warning, in __post_init__)
        Indicate whether the batch and sequence-length axes of the input tensors are
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
        else :math:`Q \cdot K^T`
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax type as described in the paper
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
    """

    head_dim: int
    num_attention_heads: int
    # None: set to num_attention_heads in __post_init__.
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    return_layernorm_output: bool = False
    zero_centered_gamma: bool = False
    # None: replaced by variance_scaling(1.0, "fan_in", "normal") in __post_init__.
    kernel_init: Optional[Initializer] = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: Optional[float] = None
    dtype: DType = jnp.float32
    fuse_qkv_params: bool = True
    # None: replaced by False (with a warning) in __post_init__.
    transpose_batch_sequence: bool | None = None
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: str = "vanilla"

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        """Resolve changed defaults and deprecated aliases, then defer to the parent post-init.

        * ``transpose_batch_sequence=None`` is replaced by ``False`` (with a warning).
        * ``num_heads`` / ``dropout_rate`` are copied into ``num_attention_heads`` /
          ``attention_dropout`` and emit a :class:`DeprecationWarning`.
        * ``apply_residual_connection_post_layernorm`` and ``fuse_qkv`` only emit a
          :class:`DeprecationWarning`; their values are not forwarded here.
        * ``output_layernorm`` must be left unset (asserted).
        * ``kernel_init=None`` becomes ``variance_scaling(1.0, "fan_in", "normal")`` in
          ``self.dtype``; ``num_gqa_groups=None`` becomes ``num_attention_heads``.
        """
        # Deal with changed defaults in API
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        # None implies that the user is relying on defaults, hence warn the user and set the new defaults
        if self.transpose_batch_sequence is None:
            warnings.warn(
                "transpose_batch_sequence defaults to False in MultiHeadAttention starting"
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False

        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        # Resolve None defaults that depend on other fields.
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Tuple[Array, Optional[Array]]:
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q: jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv: jax.numpy.ndarray
            Input tensor for key/value projection. Passing the very same array object as
            ``inputs_q`` (identity check, not equality) selects the self-attention path.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
        decode: bool, default = False
            Keyword-only. Indicate whether to prepare and use an autoregressive cache.
        deterministic: bool, default = False
            Keyword-only. Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensor. It has the same dtype and the same last dimension as ``inputs_q``.
        ln_out: Optional[jax.numpy.ndarray]
            The layernorm output returned by the query (or fused QKV) projection.
            It may be ``None`` when that projection was not asked to return it.
        """

        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

        def query_init(*args):
            # Optionally fold the 1/sqrt(head_dim) logit scaling into the query kernel init.
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            # Fused Q/K/V kernel: 3-D, with the stacked axis (size 3) at position -2.
            # Each slice is initialized separately so the query slice gets query_init.
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            # Fused K/V kernel: 3-D, with the stacked axis (size 2) at position -2.
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            # Logical axes for the two leading (batch, seqlen) dims, honoring the layout flag.
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        # A single packed QKV projection is only possible when Q and KV share inputs
        # and have the same number of heads.
        is_qkvpack = is_self_attn and not is_gqa

        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        if self.fuse_qkv_params:
            if is_qkvpack:
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_attention_heads * self.head_dim),
                    return_layernorm_output=self.return_layernorm_output,
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
                qkv_layout = QKVLayout.BS3HD
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_attention_heads * self.head_dim,
                    # Self-attention needs ln_out to feed the KV projection below.
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_TP_AXES,),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    dtype=self.dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    name="query",
                )(inputs_q)

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
                qkv_layout = QKVLayout.BSHD_BS2HD
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
                features=self.num_gqa_groups * self.head_dim,
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
            )
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=self.input_layernorm,
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_attention_heads * self.head_dim,
                return_layernorm_output=True,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
                transpose_batch_sequence=self.transpose_batch_sequence,
                name="query",
            )(inputs_q)

            if is_self_attn:
                assert ln_out is not None
                inputs_kv = ln_out

            query = query.astype(input_dtype)
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            key = key.astype(input_dtype)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
            value = value.astype(input_dtype)
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if self.enable_rotary_pos_emb:
            # RoPE operates on separate Q and K tensors, so unpack any fused projection.
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            # Split the hidden dim into (heads, head_dim) for Q and (gqa_groups, head_dim) for K/V.
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if decode:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                # The cached K/V carry num_gqa_groups heads, while the query carries
                # num_attention_heads heads. Build the expected *query* shape from the
                # query head count so GQA decoding is not rejected.
                if self.transpose_batch_sequence:
                    length, batch, _, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, _, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )

                cur_index = cache_index.value.astype(jnp.int32)
                # Write the single new K/V step into slot `cur_index` by adding a one-hot
                # masked copy. This relies on that slot still being zero.
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                # Mask out (True) cache positions that have not been written yet.
                mask = combine_masks(
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )

                if bias is not None:
                    # Keep only the bias row for the current decode position (axis -2).
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1,)), 1, -2
                    )

        leading_axes = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            leading_axes = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
            qkv_sharding_constraint = (*leading_axes, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*leading_axes, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*leading_axes, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
        else:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            # query/key/value were already reshaped to 4-D when the layout became
            # BSHD_BSHD_BSHD above (the decode cache preserves those shapes).
            qkv_sharding_constraint = (*leading_axes, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(*dpa_args, mask, bias, deterministic=deterministic)
        # Merge (heads, head_dim) back into one hidden dimension.
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = (*leading_axes, HIDDEN_TP_AXES)
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)

        out = DenseGeneral(
            features=inputs_q.shape[-1],
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")

        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
        return out, ln_out


class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Bucket indices are computed with NumPy (on the host), so q_seqlen/k_seqlen
        # must be concrete Python ints rather than traced values.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in half: keys after the query (negative_rp < 0) use the
            # upper half, offset by rpb_num_buckets.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Unidirectional: keys after the query collapse onto distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Small distances get one exact bucket each. Larger distances are log-spaced
        # up to max_distance and saturate at the last bucket.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # One-hot encode buckets as (num_buckets, q_seqlen, k_seqlen) and contract with the
        # (heads, num_buckets) table. This yields (heads, q_seqlen, k_seqlen) without a gather.
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer.

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups: int, default = None
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC2 layer.
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout
        (applied to the attention outputs and to the FC2 output).
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.0
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for
        generating Dropout masks in this layer (attention, hidden, intermediate
        and stochastic-depth dropout).
    mha_kernel_init: Initializer, default =
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    mlp_kernel_init: Initializer, default =
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')``
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    mlp_activations: Sequence[str], default = ('gelu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    mlp_activation_params: dict, default = None
        This is only used when ``('clamped_silu', 'clamped_linear')`` is in :attr:`mlp_activations`.
        At the moment ``ClampedSwiglu`` is the only activation that requires parameters.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to ``False``, the layer will not learn additive biases.
    bias_init: Initializer, default = ``flax.linen.initializers.zeros``
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    apply_residual_connection_post_layernorm: bool, default = False
        If set to ``True``, residual connections are taken from the output
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
        If set to ``True``, layer normalization is applied on the output side,
        after the final dropout-add. Default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like T5
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.

        Each described below:

        * ``no_mask``: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
          :attr:`__call__` method to specify the padding positions.
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.

    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is ``None``, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: ``RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')``
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with
        :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ``['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']``
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not ``'none'``.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
    softmax_type: {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax type as described in the paper
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
        Only supported for fused attention backend.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to ``True``, ``TransformerLayer`` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to ``True``, the input tensors
        should be in ``(seqlen, batch, hidden)``, otherwise ``(batch, seqlen, hidden)``.
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        if set to ``True``, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
        else :math:`Q \cdot K^T`
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.0
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ("gelu",)
    mlp_activation_params: dict = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: str = "vanilla"

    def __post_init__(self):
        """Fill in defaults that depend on other fields before Flax finalizes the module."""
        # Kernel initializers default to variance scaling in the layer's parameter dtype.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        # No GQA grouping requested -> one KV group per head (plain MHA).
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,
    ):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor. Must be 3-D: ``(batch, seqlen, hidden)``, or
            ``(seqlen, batch, hidden)`` when :attr:`transpose_batch_sequence=True`.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
            :attr:`True` means mask out the corresponding values.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length: int, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors, with the same dtype as :attr:`inputs`.

        Raises
        ------
        AssertionError
            If :attr:`layer_type` is not a ``TransformerLayerType`` member, if
            ``hidden_size`` is not divisible by ``num_attention_heads``, if
            ``decode=True`` for an encoder layer, if ``inputs`` is not 3-D, or if
            ``encoded`` is missing for a decoder layer.
        """
        input_dtype = inputs.dtype

        # isinstance (rather than `in TransformerLayerType`) gives the same answer on
        # every Python version: before 3.12 `in` raises TypeError for non-members, and
        # from 3.12 it accepts the raw value "encoder", which then fails further down
        # with a misleading message because the equality checks below would be False.
        assert isinstance(
            self.layer_type, TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )

        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            """Return the (batch, seqlen) logical axes in the layout given by
            transpose_batch_sequence; the seqlen axis uses SEQLEN_TP_AXES when the
            sequence is shared (defaults to enable_sequence_parallel)."""
            axes = [None, None]

            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding

            # The third argument is True for the encoder and False for the decoder
            # (presumably RelativePositionBiases' bidirectional flag — confirm).
            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                # While decoding step by step, size the bias for the full decode length.
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Make name be the exactly same as T5X, since names would affect
        # RNGKey during init and apply. Maybe unnecessary in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = "attention"
        else:
            mha_name = "self_attention"

        inputs = with_sharding_constraint_by_logical_axes(
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name,
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)

        def hidden_dropout(x, deterministic):
            """Apply hidden dropout, sharing the mask along hidden_dropout_dims."""
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        x = with_sharding_constraint_by_logical_axes(
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            # Stochastic depth: one mask value per sample, broadcast over all other dims.
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = with_sharding_constraint_by_logical_axes(
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            # Cross-attention: queries from the decoder stream, keys/values from `encoded`.
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name="encoder_decoder_attention",
                window_size=self.window_size,
                softmax_type=self.softmax_type,
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

            y = with_sharding_constraint_by_logical_axes(
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            residual = with_sharding_constraint_by_logical_axes(
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = with_sharding_constraint_by_logical_axes(
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            activation_params=self.mlp_activation_params,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            # The dot inputs never use the sequence-parallel seqlen axis.
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            transpose_batch_sequence=self.transpose_batch_sequence,
            name="mlp",
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = with_sharding_constraint_by_logical_axes(
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            # Use the same RNG collection as every other dropout in this layer; without
            # it a non-default dropout_rng_name would leave this op on the "dropout" stream.
            z = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            z = with_sharding_constraint_by_logical_axes(
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        return z