transformer.py 100 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
18
from flax.linen.attention import combine_masks
19
20
21
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
27
28
29
30
31
32
from ..attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    SequenceDescriptor,
)
33
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
34
from ..attention import fused_attn
35
from ..attention import CPStrategy
36
from ..softmax import SoftmaxFusionType
37
38
39
40
41
42
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
43
44
45
46
47

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
48
49
50
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
51
52
53
54
55
56
57
58
59
60
61
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


Phuong Nguyen's avatar
Phuong Nguyen committed
62
# TODO(Phuong): move this function to sharding.py
63
64
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to Figure 3 in `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules: the given rules, in their original order,
        followed by every TE rule whose axis name is not already present.

    Raises
    ------
    AssertionError
        If a rule is not a 2-element ``(str, str | None)`` pair, or if a given rule maps an
        axis name differently from TE's own rule for that name.
    """
    # Group the user-provided mesh axes by logical axis name so conflicts can be detected.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key = item[0]
        val = item[1]
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for item in get_sharding_map_logic_axis_to_mesh_axis().items():
        key = item[0]
        val = item[1]
        if key in rules_map:
            # A user rule for a TE axis name must be unique and agree with TE's mapping.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


120
121
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot product attention built from einsum contractions and the ``Softmax`` module."""

    # Dropout probability applied to the attention weights after the softmax.
    attention_dropout: float = 0.0
    # Mask type; selects the softmax fusion type (see convert_to_softmax_fusion_type below).
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    # Bias type; decides whether the scale is applied before or inside the Softmax module.
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    # When True, query and key are cast to float32 before computing the logits.
    float32_logits: bool = False
    # Logit scale; None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # True: inputs are (seqlen, batch, ...); False: inputs are (batch, seqlen, ...).
    transpose_batch_sequence: bool = False
    # Sliding window size forwarded to make_swa_mask.
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Compute attention over ``query``, ``key`` and ``value``.

        Parameters
        ----------
        query, key, value: Array
            Tensors of equal rank laid out as ``[b, s, h, d]`` (or ``[s, b, h, d]`` when
            :attr:`transpose_batch_sequence` is True). ``key`` and ``value`` must share
            sequence length and head count; the query head count must be a multiple of the
            key head count (grouped-query attention when they differ).
        mask: Optional[Array]
            Mask in which 1 means "masked out". Ignored for ``NO_MASK`` and for
            ``CAUSAL_MASK`` without a sliding window.
        bias: Optional[Array]
            Attention bias; added before the softmax scaling for ``PRE_SCALE_BIAS``,
            otherwise forwarded to the Softmax module.
        dropout_rng: Optional[PRNGKey]
            RNG key used to sample the dropout mask.
        deterministic: bool
            Disables attention dropout when True.

        Returns
        -------
        Array
            Attention output with the same layout as ``query``.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        input_dtype = query.dtype

        # Infer number of attention heads from query shape
        # query shape: [..., h, d] where h is num_attention_heads
        num_attention_heads = query.shape[-2]

        # Initialize softmax_offset for learnable softmax
        # Note: OFF_BY_ONE_SOFTMAX is handled internally by the Softmax module
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        # Resolve the scale into a local. The module attribute is deliberately left untouched:
        # deleting it would mutate the module and break a second call on the same instance.
        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (kv_head, group) so each kv head serves a group.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Fold the group axis into the head axis for masking/softmax; restored afterwards.
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # (b, h, q, k): Last two axes are always replicated
        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                bias = None

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            # Keep positions already masked by original_mask; elsewhere take the SWA mask.
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_fusion_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxFusionType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxFusionType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxFusionType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_fusion_type, mask = convert_to_softmax_fusion_type(self.attn_mask_type, mask)

        attn_weights = Softmax(
            softmax_fusion_type=softmax_fusion_type,
            softmax_type=self.softmax_type,
            scale_factor=fused_scale_factor,
        )(attn_weights, mask, bias, softmax_offset=softmax_offset).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Rescale kept entries by 1 / keep_prob.
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
291
292


293
294
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot product attention that dispatches to ``fused_attn``."""

    # Dropout probability forwarded to fused_attn.
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    # Decides which of query/key/value are read and how they are packed (see __call__).
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    # Logit scale; None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # True: inputs/outputs are (seqlen, batch, ...); they are transposed around fused_attn.
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Run fused attention for the configured :attr:`qkv_layout`.

        Parameters
        ----------
        query, key, value: Array
            Interpreted according to :attr:`qkv_layout`:

            * qkvpacked: ``query`` is the packed tensor ``[..., 3, h, d]``; ``key`` and
              ``value`` are ignored.
            * kvpacked: ``query`` is ``[..., h, d]``, ``key`` is the packed tensor
              ``[..., 2, h, d]``; ``value`` is ignored.
            * separate: all three are used as given.
        sequence_descriptor: Optional[SequenceDescriptor]
            Forwarded to ``fused_attn``.
        bias: Optional[Array]
            Attention bias forwarded to ``fused_attn``.
        dropout_rng: Optional[PRNGKey]
            RNG key; when given it is split into ``num_of_devices()`` keys used as the seed.
        deterministic: bool
            ``fused_attn`` is called with ``is_training=not deterministic``.

        Returns
        -------
        Array
            Attention output, transposed back to (seqlen, batch, ...) when
            :attr:`transpose_batch_sequence` is True.

        Raises
        ------
        ValueError
            If :attr:`qkv_layout` is neither qkvpacked, kvpacked nor separate.
        """
        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        # Resolve the scale into a local. The module attribute is deliberately left untouched:
        # deleting it would mutate the module and break a second call on the same instance.
        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor

        num_attention_heads = query.shape[-2]
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        # Build the tuple of tensors handed to fused_attn; only this part differs per layout.
        if self.qkv_layout.is_qkvpacked():
            # qkvpacked format, treat
            #   query: qkvpacked tensor, shape = [..., 3, h, d]
            #   key: ignore
            #   value: ignore
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            # kvpacked format, treat
            #   query: query tensor, shape = [..., h, d]
            #   key: kvpacked tensor, shape = [..., 2, h, d]
            #   value: ignore
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            qkv,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            softmax_type=self.softmax_type,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
            context_parallel_strategy=self.context_parallel_strategy,
            context_checkpoint_name=self.context_checkpoint_name,
            softmax_offset=softmax_offset,
        )

        if self.transpose_batch_sequence:
            # Restore the caller's (seqlen, batch, ...) layout.
            x = x.transpose([1, 0, 2, 3])

        assert x.dtype == query.dtype
        return x


437
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

453
454
455
456
        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
          attention kernel is not available on the system, a warning will be issued, and the module
          will automatically fall back to the unfused backend.
457

458
459
460
461
462
463
464
465
    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

466
467
468
469
470
471
    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
Paweł Gadziński's avatar
Paweł Gadziński committed
472
    num_gqa_groups: int, default = None
473
474
475
476
477
478
479
480
481
482
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
483
484
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
Paweł Gadziński's avatar
Paweł Gadziński committed
485
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
486
487
488

        Each described below:

Paweł Gadziński's avatar
Paweł Gadziński committed
489
        * ``no_mask``: No attention mask is applied. This means the attention will consider the
490
          full sequence without any restrictions.
Paweł Gadziński's avatar
Paweł Gadziński committed
491
492
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
493
          :attr:`__call__` method to specify the padding positions.
Paweł Gadziński's avatar
Paweł Gadziński committed
494
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
495
496
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
Paweł Gadziński's avatar
Paweł Gadziński committed
497
498
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
499

Paweł Gadziński's avatar
Paweł Gadziński committed
500
        |
501

Paweł Gadziński's avatar
Paweł Gadziński committed
502
        .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
503

Paweł Gadziński's avatar
Paweł Gadziński committed
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
        |

        .. note:: THD format only supports ``'padding'`` or ``'causal_padding'`` mask type.

        |

        .. table::
            :widths: auto

            ================== ============ ========== ==============================
            attn_mask_type     mask/sd      SWA        softmax type
            ================== ============ ========== ==============================
            no_mask            None         None       SCALED
            causal             None         None       SCALED_UPPER_TRIANG_MASKED
            causal             None         Yes        SCALED_MASKED
            padding            Required     Yes/No     SCALED_MASKED
            padding_causal     Required     Yes/No     SCALED_MASKED
            ================== ============ ========== ==============================

        where sd stands for sequence_descriptor.
524

525
    attn_bias_type: Optional[str], default = None
526
        Type of the attention bias passed in the attention.
527
528
529
530
531
532
533
534
535
536
537
538
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
539
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.
540
541
542
543
544
545

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].
546
547
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.
548
549
550
551
552
553
554
555
556
557
558
559

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
560
561
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
562
        Indicate whether the input tensors were switched axis of batch
563
        and sequence length dimension. If set to True, the input tensors
564
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
565
566
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
567
568
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
Paweł Gadziński's avatar
Paweł Gadziński committed
569
570
571
572
573
574
575
576
    context_parallel_causal_load_balanced: bool
        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
    context_parallel_axis: str
        The name of the context parallel axis.
    context_parallel_strategy: CPStrategy
        The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
    context_checkpoint_name: str
        The name of the context checkpoint in the forward pass of fused attention.
577
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Paweł Gadziński's avatar
Paweł Gadziński committed
578
        Softmax type as described in the paper
579
580
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
603
604
605

    Optimization parameters
    -----------------------
606
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
607
        The data type used to allocate the initial parameters.
608
    """
609

610
611
612
    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
613
614
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
615
616
    attn_bias_type: AttnBiasType = None
    dtype: DType = jnp.float32
617
    dropout_rng_name: str = "dropout"
618
    float32_logits: bool = False
619
    qkv_layout: str = "bshd_bshd_bshd"
620
    scale_factor: Optional[float] = None
621
    transpose_batch_sequence: bool | None = None
622
    window_size: Optional[Tuple[int, int]] = None
623
    max_segments_per_seq: Optional[int] = 1
624
625
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
626
    context_parallel_strategy: str = "DEFAULT"
627
    context_checkpoint_name: str = "context"
628
    softmax_type: str = "vanilla"
629

630
631
632
633
634
635
636
637
638
639
640
    def __post_init__(self):
        """Resolve an unset ``transpose_batch_sequence`` to ``False`` and warn about it."""
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        # A value of None means the caller did not choose a layout and is relying on the
        # default, so announce the default being applied before committing to it.
        relies_on_default = self.transpose_batch_sequence is None
        if relies_on_default:
            notice = (
                "transpose_batch_sequence defaults to False in DotProductAttention starting"
                " TransformerEngine v2.10"
            )
            warnings.warn(notice)
            self.transpose_batch_sequence = False
        super().__post_init__()

641
    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Run dot product attention, using the fused backend when it is enabled
        (environment variable ``NVTE_FUSED_ATTN``, default ``1``) and a fused kernel is
        available for the configuration, and the unfused backend otherwise.

        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        sequence_descriptor: SequenceDescriptor or jax.numpy.ndarray, default = None
            Sequence description (or mask array) forwarded to the selected attention backend.
            The unfused backend only accepts ``None`` or an array.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: jax.numpy.ndarray, default = None
            Deprecated, please use :attr:`sequence_descriptor` instead; the two cannot be
            provided at the same time.
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.

        Raises
        ------
        ValueError
            If both :attr:`mask` and :attr:`sequence_descriptor` are given, or if
            :attr:`self.context_parallel_strategy` is not a recognized strategy name.
        """
        input_dtype = query.dtype

        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
        # NOTE: the string-valued module attributes are intentionally left in place.
        # Deleting them from the instance would make a later call on the same module
        # instance read the dataclass defaults instead of the configured values.
        # Only the canonicalized locals above are used from here on.

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # Use fused attn (if kernel check below passes) by default
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout == QKVLayout.BS3HD:
            # Packed QKV: key/value share the query tensor, hence the same sequence length.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        if qkv_layout.is_separate():
            head_dim_qk = query.shape[-1]
            head_dim_v = value.shape[-1]
        else:
            head_dim_qk = self.head_dim
            head_dim_v = self.head_dim

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            head_dim_qk,
            head_dim_v,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            # Report every configuration value that was passed to the kernel check.
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{softmax_type=}\n{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
                f"{self.window_size=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(head_dim_qk)
        else:
            scale_factor = self.scale_factor

        # case-insensitive mapping for context parallel strategy
        cp_strategy_map = {
            "DEFAULT": CPStrategy.DEFAULT,
            "ALL_GATHER": CPStrategy.ALL_GATHER,
            "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
            "RING": CPStrategy.RING,
        }

        strategy_key = self.context_parallel_strategy.upper()
        if strategy_key in cp_strategy_map:
            context_parallel_strategy = cp_strategy_map[strategy_key]
        else:
            valid_strategies = list(cp_strategy_map.keys())
            raise ValueError(
                f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
                f"Valid options are: {valid_strategies} (case insensitive)"
            )

        if not use_fused_attn:
            # unfused attention only supports splitted query, key, value
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x
842
843


844
845
846
847
848
849
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding

    x should be of shape
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor; the rotation is applied along its last (hidden) axis.
    windows: Tuple[int, int]
        The (min, max) time-scales of the rotary embedding.
    transpose_batch_sequence: bool
        Whether the sequence axis comes before the batch axis in ``x``.
    group_method: str, default = 'consecutive'
        How coordinates are paired: ``'consecutive'`` pairs index ``i`` with ``i + 1``,
        ``'alternate'`` pairs index ``i`` with ``i + hidden / 2``. Matching is
        case-insensitive and ignores surrounding whitespace, ``-`` and ``_``.

    Returns
    -------
    outputs: jax.numpy.ndarray
        Tensor with the same shape and dtype as ``x``.
    """
    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2
    min_window = windows[0]
    max_window = windows[1]

    # One time-scale per coordinate pair, geometrically spaced from min to max window.
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Prepend singleton axes so the time-scales broadcast against all leading axes of x.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Position indices along the sequence axis, with singleton axes everywhere else.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        sin, cos = generate_sin_cos(time_scales)

        # Pair coordinate i of the first half with coordinate i of the second half.
        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
        return output

    def consecutive_impl():
        # Each time-scale is shared by the two adjacent coordinates of a pair.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        # Build the "partner" tensor: even indices take their right neighbour,
        # odd indices take their left neighbour.
        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 on even indices, +1 on odd indices.
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid relative positional embedding group method. "
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."
        )

        return canonicalized_gm

    group_method = canonicalize_group_method(group_method)

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()
921
922


923
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    Three boolean flags recording which groups of layers low rank adaptation is
    applied to: the QKV projection, the output projection and the MLP.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Defer to the other operand instead of raising AttributeError when compared
        # with an unrelated type.
        if not isinstance(other, LoRAScope):
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == (
            other.qkv_proj,
            other.output_proj,
            other.mlp,
        )

    # Instances are mutable and compare by value, so they stay unhashable
    # (this is also what Python does implicitly when only __eq__ is defined).
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj!r}, "
            f"output_proj={self.output_proj!r}, mlp={self.mlp!r})"
        )
937
938
939
940


def _canonicalize_lora_scope(scope):
    """Convert a LoRA scope name into a :class:`LoRAScope`.

    Parameters
    ----------
    scope: str or None
        Case-insensitive scope name. ``None`` is treated as ``'none'``.
        ``'exclude_<x>'`` enables every group except ``<x>``.

    Returns
    -------
    lora_scope: LoRAScope
        Flags for the QKV projection, output projection and MLP groups.

    Raises
    ------
    AssertionError
        If ``scope`` is not one of the supported names.
    """
    SCOPE_NONE = "none"
    SCOPE_ALL = "all"
    SCOPE_QKV_PROJ = "qkv_proj"
    SCOPE_OUTPUT_PROJ = "output_proj"
    SCOPE_MLP = "mlp"
    SCOPE_EX_QKV_PROJ = "exclude_qkv_proj"
    SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj"
    SCOPE_EX_MLP = "exclude_mlp"

    scope = SCOPE_NONE if scope is None else scope

    scope = scope.lower()

    valid_scopes = [
        SCOPE_NONE,
        SCOPE_ALL,
        SCOPE_QKV_PROJ,
        SCOPE_OUTPUT_PROJ,
        SCOPE_MLP,
        SCOPE_EX_QKV_PROJ,
        SCOPE_EX_OUTPUT_PROJ,
        SCOPE_EX_MLP,
    ]
    assert (
        scope in valid_scopes
    ), f"Invalid low rank adaptation scope: {scope!r}. Expect one of {valid_scopes}."

    lora_scope = LoRAScope()

    if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
        lora_scope.qkv_proj = True

    if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
        lora_scope.output_proj = True

    if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
        lora_scope.mlp = True

    return lora_scope


979
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
980
981
982
983
984
985
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
986
    head_dim: int
987
        The hidden dimension of each attention head.
988
989
    num_attention_heads: int
        The number of attention heads.
Paweł Gadziński's avatar
Paweł Gadziński committed
990
    num_gqa_groups: int, default = None
991
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
zlsh80826's avatar
zlsh80826 committed
992
993
994
995
996
997
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
998
999
1000
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
1001
1002
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
Paweł Gadziński's avatar
Paweł Gadziński committed
1003
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
1004
1005
1006

        Each described below:

Paweł Gadziński's avatar
Paweł Gadziński committed
1007
        * ``no_mask``: No attention mask is applied. This means the attention will consider the
1008
          full sequence without any restrictions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1009
1010
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
1011
          :attr:`__call__` method to specify the padding positions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1012
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
1013
1014
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
Paweł Gadziński's avatar
Paweł Gadziński committed
1015
1016
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
1017

Paweł Gadziński's avatar
Paweł Gadziński committed
1018
        .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
1019

1020
1021
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
Paweł Gadziński's avatar
Paweł Gadziński committed
1022
        Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
1023
        When default is present, the type is automatically decided by the MHA's bias parameter.
Paweł Gadziński's avatar
Paweł Gadziński committed
1024
        Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used.
1025
    dropout_rng_name: str, default = 'dropout'
1026
1027
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
1028
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1029
        Indicate the type of layer normalization.
1030
    layernorm_epsilon: float, default = 1e-6
1031
        A value added to the denominator of layer normalization for numerical stability.
1032
    zero_centered_gamma: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1033
        If set to ``True``, the LayerNorm formula changes to
1034
1035

        .. math::
Paweł Gadziński's avatar
Paweł Gadziński committed
1036
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
1037
1038
            (1 + \gamma) + \beta

Paweł Gadziński's avatar
Paweł Gadziński committed
1039
        This parameter is only applicable for ``'layernorm'``.
1040
    kernel_init: Initializer, default =
Paweł Gadziński's avatar
Paweł Gadziński committed
1041
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
1042
        Used for initializing the QKV and output projection weights.
Paweł Gadziński's avatar
Paweł Gadziński committed
1043
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1044
    use_bias: bool, default = False
1045
        Indicate whether or not to enable bias shifting for QKV and output projections.
Paweł Gadziński's avatar
Paweł Gadziński committed
1046
1047
        If set to ``False``, the layer will not learn additive biases.
    bias_init: Initializer, default = ``flax.linen.initializers.zeros``
1048
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
Paweł Gadziński's avatar
Paweł Gadziński committed
1049
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1050
    input_layernorm: bool, default = True
Paweł Gadziński's avatar
Paweł Gadziński committed
1051
        If set to ``False``, layer normalization to the input is not applied.
1052
    return_layernorm_output: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1053
        If set to ``True``, output of layernorm is returned from the forward together with the output
1054
1055
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
1056
1057
1058
1059
1060
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1061
1062
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
Paweł Gadziński's avatar
Paweł Gadziński committed
1063
1064
        ``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with :math:`i + 1`.
1065
1066
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
Paweł Gadziński's avatar
Paweł Gadziński committed
1067
        ``['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']``
1068
1069
1070
1071
1072
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` is not ``'none'``.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
Paweł Gadziński's avatar
Paweł Gadziński committed
1073
        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
1074
1075
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1076
1077
1078
1079
1080
1081
1082
1083
    num_heads: int, default = None
        Deprecated. Please refer to `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer to `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer to `input_layernorm`.
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer to `return_layernorm_output`.
1084
1085
1086

    Optimization parameters
    -----------------------
1087
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1088
        The data type used to allocate the initial parameters.
1089
    fuse_qkv_params: bool, default = True
1090
        If set to True, this module exposes a single fused
1091
1092
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1093
1094
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
1095
        Indicate whether the input tensors were switched axis of batch
1096
1097
1098
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
1099
        Indicate whether to scale attention logits.
Paweł Gadziński's avatar
Paweł Gadziński committed
1100
1101
        If set to True, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
        else :math:`Q \cdot K^T`
1102
1103
1104
1105
1106
1107
1108
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
1109
1110
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1111
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Paweł Gadziński's avatar
Paweł Gadziński committed
1112
        Softmax type as described in the paper
1113
1114
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
1137
1138
1139
    """

    head_dim: int
1140
1141
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
1142
1143
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
1144
    input_layernorm: bool = True
1145
1146
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
1147
    return_layernorm_output: bool = False
1148
    zero_centered_gamma: bool = False
1149
1150
1151
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
1152
    attn_mask_type: str = "causal"
1153
    attn_bias_type: Optional[str] = None
1154
1155
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1156
1157
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1158
1159
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1160
    dtype: DType = jnp.float32
1161
    fuse_qkv_params: bool = True
1162
    transpose_batch_sequence: bool | None = None
1163
    enable_sequence_parallel: bool = False
1164
1165
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1166
    float32_logits: bool = False
1167
    window_size: Optional[Tuple[int, int]] = None
1168
    softmax_type: str = "vanilla"
1169
1170
1171
1172
1173
1174
1175

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
1176
1177

    def __post_init__(self):
        """Resolve changed defaults and deprecated parameters, then fill in derived defaults.

        Raises
        ------
        AssertionError
            If the removed ``output_layernorm`` parameter is set.
        """
        # Deal with changed defaults in API
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        # None implies that the user is relying on defaults, hence warn the user and set the new defaults
        if self.transpose_batch_sequence is None:
            warnings.warn(
                "transpose_batch_sequence defaults to False in MultiHeadAttention starting"
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False

        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        # NOTE(review): unlike num_heads / dropout_rate above, the value of this
        # deprecated parameter is not forwarded to its replacement - confirm intended.
        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        # NOTE(review): fuse_qkv is likewise only warned about, not forwarded to
        # fuse_qkv_params - confirm intended.
        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        # Fill in defaults that depend on other fields.
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1238
1239
1240
1241
1242
1243
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1244
        inputs_q: jax.numpy.ndarray
1245
            Input tensor for query projection.
1246
        inputs_kv: jax.numpy.ndarray
1247
            Input tensor for key/value projection.
1248
1249
1250
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
Paweł Gadziński's avatar
Paweł Gadziński committed
1251
            Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
1252
1253
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1254
        *
1255
        decode: bool, default = False
1256
            Indicate whether to prepare and use an autoregressive cache.
1257
        deterministic: bool, default = False
1258
1259
1260
1261
            Disable dropout layers if set to True.

        Returns
        -------
1262
        outputs: jax.numpy.ndarray
1263
1264
            Output tensors.
        """
1265

1266
1267
1268
1269
1270
        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

1271
        def query_init(*args):
1272
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            """Return the logical axis names for the two leading (batch/sequence) dims.

            The sequence dim comes first when ``self.transpose_batch_sequence`` is
            set, otherwise the batch dim does. The sequence dim is labelled
            ``SEQLEN_TP_AXES`` when ``is_sharded_seq`` is truthy and ``SEQLEN_AXES``
            otherwise.
            """
            seqlen_axes = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            if self.transpose_batch_sequence:
                return (seqlen_axes, BATCH_AXES)
            return (BATCH_AXES, seqlen_axes)

1315
1316
1317
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1318

1319
1320
1321
1322
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1323
1324
1325
1326
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1327
1328
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1329
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1330
            if is_qkvpack:
1331
                qkv_proj, ln_out = LayerNormDenseGeneral(
1332
                    enable_layernorm=self.input_layernorm,
1333
                    layernorm_type=self.layernorm_type,
1334
                    zero_centered_gamma=self.zero_centered_gamma,
1335
1336
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1337
1338
                    features=(3, self.num_attention_heads * self.head_dim),
                    return_layernorm_output=self.return_layernorm_output,
1339
1340
1341
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1342
1343
1344
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1345
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1346
1347
1348
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1349
1350
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1351
                    transpose_batch_sequence=self.transpose_batch_sequence,
1352
1353
1354
1355
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1356
                qkv_layout = QKVLayout.BS3HD
1357
1358
            else:
                query, ln_out = LayerNormDenseGeneral(
1359
                    enable_layernorm=self.input_layernorm,
1360
                    layernorm_type=self.layernorm_type,
1361
                    zero_centered_gamma=self.zero_centered_gamma,
1362
1363
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1364
1365
                    features=self.num_attention_heads * self.head_dim,
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1366
1367
1368
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1369
1370
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1371
                    bias_axes=(W_TP_AXES,),
1372
1373
1374
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1375
1376
                    dtype=self.dtype,
                    kernel_init=query_init,
1377
1378
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1379
                    transpose_batch_sequence=self.transpose_batch_sequence,
1380
1381
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1382
1383
1384
1385
1386

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1398
                    transpose_batch_sequence=self.transpose_batch_sequence,
1399
1400
1401
1402
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1403
                qkv_layout = QKVLayout.BSHD_BS2HD
1404
1405
1406
1407
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1408
                features=self.num_gqa_groups * self.head_dim,
1409
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1410
1411
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1412
                bias_axes=(W_TP_AXES,),
1413
1414
1415
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1416
1417
                dtype=self.dtype,
            )
1418
            query, ln_out = LayerNormDenseGeneral(
1419
                enable_layernorm=self.input_layernorm,
1420
                layernorm_type=self.layernorm_type,
1421
                zero_centered_gamma=self.zero_centered_gamma,
1422
1423
                epsilon=self.layernorm_epsilon,
                axis=-1,
1424
                features=self.num_attention_heads * self.head_dim,
1425
                return_layernorm_output=True,
1426
1427
1428
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1429
1430
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1431
                bias_axes=(W_TP_AXES,),
1432
1433
1434
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1435
1436
                dtype=self.dtype,
                kernel_init=query_init,
1437
1438
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1439
                transpose_batch_sequence=self.transpose_batch_sequence,
1440
1441
                name="query",
            )(inputs_q)
1442

1443
            if is_self_attn:
1444
1445
1446
                assert ln_out is not None
                inputs_kv = ln_out

1447
            query = query.astype(input_dtype)
1448
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1449
            key = key.astype(input_dtype)
1450
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
1451
            value = value.astype(input_dtype)
1452
1453
1454
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1455
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1456

1457
        if self.enable_rotary_pos_emb:
1458
1459
1460
1461
1462
1463
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1464

1465
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1466
1467
1468
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1481
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1482

1483
1484
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1485
1486
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1487
1488

        if decode:
1489
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1490
1491
1492
1493
1494
1495
1496
1497
1498
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1499
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1500
                if self.transpose_batch_sequence:
1501
1502
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1503
1504
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1505
1506
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1507
                    one_hot_indices_shape = (1, length, 1, 1)
1508
1509
1510
1511

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1512
1513
1514
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1515

1516
                cur_index = cache_index.value.astype(jnp.int32)
1517
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1518
1519
1520
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1521
1522
1523
1524
1525
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1526
1527
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1528
1529

                if bias is not None:
1530
1531
1532
1533
1534
1535
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1536

1537
1538
1539
1540
1541
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1542
1543
1544
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1556
        else:
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1567
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1581
            window_size=self.window_size,
1582
            softmax_type=self.softmax_type,
1583
        )(*dpa_args, mask, bias, deterministic=deterministic)
1584
1585
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1586
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1587
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1588

1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1604

1605
1606
1607
        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
1608
        return out, ln_out
1609
1610


1611
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Every (query position, key position) pair is mapped to a bucket index that
    depends only on the signed distance between the two positions. A parameter
    table of shape ``(num_attention_heads, num_buckets)`` is then indexed with
    that bucket to produce a per-head additive attention bias.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # relative_position[i, j] = j - i, with shape (q_seqlen, k_seqlen).
        context_position = np.arange(q_seqlen, dtype=np.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=np.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in two halves: pairs where the key position lies
            # after the query position (negative_rp < 0) are offset into the
            # upper half, and the distance itself becomes unsigned.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Key positions after the query position all collapse to distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Distances below rpb_max_exact each get their own bucket; larger
        # distances are spread logarithmically over the remaining buckets, with
        # anything beyond the range clamped into the last bucket.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            # The epsilon keeps log() finite for a zero distance.
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # One-hot encode the bucket ids along a leading num_buckets axis:
        # rp_bucket_one_hot has shape (num_buckets, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        # Contract the bucket axis of the table (dim 1) with the bucket axis of
        # the one-hot tensor (dim 0), giving (num_attention_heads, q_seqlen, k_seqlen).
        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        # Prepend a singleton leading axis to match the documented output shape.
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


1723
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
1724
1725
1726
1727
1728
1729
1730
1731
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
1732
        The hidden size of each input sample.
1733
    mlp_hidden_size: int, default = 2048
1734
        Intermediate size to which input samples are projected.
1735
    num_attention_heads: int, default = 8
1736
        Number of attention heads in the transformer layer.
Paweł Gadziński's avatar
Paweł Gadziński committed
1737
    num_gqa_groups: int, default = None
zlsh80826's avatar
zlsh80826 committed
1738
1739
1740
1741
1742
1743
1744
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1745
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1746
        Indicate the type of layer normalization.
1747
    layernorm_epsilon: float, default = 1e-6
1748
        A value added to the denominator of layer normalization for numerical stability.
1749
    zero_centered_gamma: bool, default = False
1750
1751
1752
1753
1754
1755
1756
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1757
    hidden_dropout: float, default = 0.1
1758
        Dropout probability for the dropout op after FC2 layer.
1759
    hidden_dropout_dims: Sequence[int], default = ()
1760
        Dimensions that will share the same dropout mask for hidden
1761
    attention_dropout: float, default = 0.1
1762
        Dropout probability for the dropout op during multi-head attention.
1763
    intermediate_dropout: float, default = 0.0
1764
1765
1766
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
1767
    dropout_rng_name: str, default = 'dropout'
1768
1769
        The key in given RNGs via flax.linen.Module.apply that is used for
        generating Dropout masks in the Multi-Head Attention.
1770
    mha_kernel_init: Initializer, default =
Paweł Gadziński's avatar
Paweł Gadziński committed
1771
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
1772
        Used for initializing weights of QKV and Output projection weights.
Paweł Gadziński's avatar
Paweł Gadziński committed
1773
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1774
    mlp_kernel_init: Initializer, default =
Paweł Gadziński's avatar
Paweł Gadziński committed
1775
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')``
1776
        Used for initializing weights of FC1 and FC2 layers.
Paweł Gadziński's avatar
Paweł Gadziński committed
1777
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1778
    mlp_activations: Sequence[str], default = ('gelu', )
1779
        The sequence of activation functions to apply after the first linear transformation.
1780
        Each activation has its own transformation layer.
1781
    mlp_activation_params: dict = None
Paweł Gadziński's avatar
Paweł Gadziński committed
1782
1783
        This is only used when ``('clamped_silu', 'clamped_linear')`` is in :attr:`mlp_activations`. At the moment
        ``ClampedSwiglu`` is the only activation that requires parameters.
1784
    use_bias: bool, default = False
1785
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
Paweł Gadziński's avatar
Paweł Gadziński committed
1786
1787
        If set to ``False``, the layer will not learn additive biases.
    bias_init: Initializer, default = ``flax.linen.initializers.zeros``
1788
1789
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
Paweł Gadziński's avatar
Paweł Gadziński committed
1790
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1791
    apply_residual_connection_post_layernorm: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1792
        If set to ``True``, residual connections are taken from the output
1793
1794
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1795
        If set to ``True``, layer normalization is applied on the output side,
1796
1797
1798
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
1799
1800
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
1801
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
1802
        If set to TransformerLayerType.DECODER, an additional cross-attention block
Paweł Gadziński's avatar
Paweł Gadziński committed
1803
        is added after self-attention. This can be used for structures like T5
1804
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
1805
    self_attn_mask_type: str, default = 'causal'
1806
1807
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
Paweł Gadziński's avatar
Paweł Gadziński committed
1808
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
1809
1810
1811

        Each described below:

Paweł Gadziński's avatar
Paweł Gadziński committed
1812
        * ``no_mask``: No attention mask is applied. This means the self attention will consider the
1813
          full sequence without any restrictions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1814
1815
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
1816
          :attr:`__call__` method to specify the padding positions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1817
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
1818
1819
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
Paweł Gadziński's avatar
Paweł Gadziński committed
1820
1821
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
1822

Paweł Gadziński's avatar
Paweł Gadziński committed
1823
        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
1824

1825
1826
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
Paweł Gadziński's avatar
Paweł Gadziński committed
1827
        Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
1828
        When default is present, the type is automatically decided by the MHA's bias parameter.
Paweł Gadziński's avatar
Paweł Gadziński committed
1829
        Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used.
1830
    enable_relative_embedding: bool, default = True
1831
        Whether to enable relative embedding as shifting of attention logits.
1832
    relative_embedding: flax.linen.Module, default = None
1833
        The module for relative embedding execution, only used when
Paweł Gadziński's avatar
Paweł Gadziński committed
1834
        :attr:`enable_relative_embedding=True`. Default is ``None``, which will create
1835
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
Paweł Gadziński's avatar
Paweł Gadziński committed
1836
        Default: ``RelativePositionBiases( num_buckets=32, max_distance=128,
1837
1838
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
Paweł Gadziński's avatar
Paweł Gadziński committed
1839
        name='relpos_bias')``
1840
1841
1842
1843
1844
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1845
    rotary_pos_emb_group_method: str, default = 'consecutive'
1846
        Indicate the method to couple the coordinates. It should be one of
Paweł Gadziński's avatar
Paweł Gadziński committed
1847
1848
        ``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with
1849
        :math:`i + 1`.
1850
1851
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
Paweł Gadziński's avatar
Paweł Gadziński committed
1852
1853
        ``['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']``
1854
1855
1856
1857
1858
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
Paweł Gadziński's avatar
Paweł Gadziński committed
1859
        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
1860
1861
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1862
1863
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1864
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Paweł Gadziński's avatar
Paweł Gadziński committed
1865
        Softmax type as described in the paper
1866
1867
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
1890
        Only supported for fused attention backend.
1891
1892
1893

    Optimization parameters
    -----------------------
1894
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1895
        The data type used to allocate the initial parameters.
1896
    drop_path: float, default = 0.0
1897
        When > 0.0, applies stochastic depth per sample in the main
1898
1899
        path of the residual block.
    fuse_qkv_params: bool, default = True
Paweł Gadziński's avatar
Paweł Gadziński committed
1900
        If set to ``True``, ``TransformerLayer`` module exposes a single fused
1901
1902
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1903
    transpose_batch_sequence: bool, default = False
1904
        Indicate whether the input tensors were switched axis of batch
Paweł Gadziński's avatar
Paweł Gadziński committed
1905
1906
        and sequence length dimension. If set to ``True``, the input tensors
        should be in ``(seqlen, batch, hidden)``, otherwise ``(batch, seqlen, hidden)``.
1907
    scale_attn_logits: bool, default = False
1908
        Indicate whether to scale attention logits.
Paweł Gadziński's avatar
Paweł Gadziński committed
1909
1910
1911
1912
        If set to ``True``, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
        else :math:`Q \cdot K^T`.
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
1913
1914
1915
1916
1917
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
1918
    num_gqa_groups: Optional[int] = None
1919
    layernorm_type: str = "layernorm"
1920
    layernorm_epsilon: float = 1e-6
1921
    zero_centered_gamma: bool = False
1922
1923
1924
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
1925
    intermediate_dropout: float = 0.0
1926
    intermediate_dropout_dims: Sequence[int] = ()
1927
    dropout_rng_name: str = "dropout"
1928
1929
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
1930
    mlp_activations: Sequence[str] = ("gelu",)
1931
    mlp_activation_params: dict = None
1932
1933
1934
1935
1936
1937
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
1938
    self_attn_mask_type: str = "causal"
1939
    self_attn_bias_type: Optional[str] = None
1940
1941
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
1942
1943
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1944
1945
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1946
1947
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1948
1949
1950
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
1951
    transpose_batch_sequence: bool = False
1952
    enable_sequence_parallel: bool = False
1953
1954
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1955
    window_size: Optional[Tuple[int, int]] = None
1956
    softmax_type: str = "vanilla"
1957
1958
1959

    def __post_init__(self):
        """Resolve fields left as ``None`` to their defaults, then defer to the parent.

        * ``mha_kernel_init`` / ``mlp_kernel_init`` fall back to a ``fan_in``
          variance-scaling initializer (``normal`` and ``truncated_normal``
          distributions respectively) created with ``self.dtype``.
        * ``num_gqa_groups`` falls back to ``num_attention_heads``.
        """
        # The two kernel initializers differ only in the sampling distribution.
        fallback_distributions = (
            ("mha_kernel_init", "normal"),
            ("mlp_kernel_init", "truncated_normal"),
        )
        for field_name, distribution in fallback_distributions:
            if getattr(self, field_name) is not None:
                continue
            setattr(
                self,
                field_name,
                nn.initializers.variance_scaling(1.0, "fan_in", distribution, dtype=self.dtype),
            )

        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads

        # Must run last so the parent sees the fully-populated fields.
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,
    ):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        The layer runs self-attention, then (for
        :attr:`layer_type=TransformerLayerType.DECODER` only) cross-attention over
        :attr:`encoded`, then the MLP block. Each sub-block output goes through hidden
        dropout and is added to a residual; self-attention and MLP outputs additionally
        go through drop path when :attr:`drop_path` > 0.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor. Must be 3-dimensional.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`. Required in that case.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
            :attr:`True` means mask out the corresponding values.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA). Must be False when
            :attr:`layer_type=TransformerLayerType.ENCODER`.
        max_decode_length: int, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`. Only used when :attr:`decode=True`;
            otherwise the sequence length of :attr:`inputs` is used.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors, with the same dtype as :attr:`inputs`.
        """

        input_dtype = inputs.dtype
        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )

        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            # Logical sharding axes for the leading (batch, seqlen) dims, ordered to
            # match `transpose_batch_sequence`. `is_shared_seq=None` defers to
            # `enable_sequence_parallel`.
            axes = [None, None]

            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        def constrain_activation(tensor):
            # Pin a [batch/seq, seq/batch, hidden] activation to the default layout.
            return with_sharding_constraint_by_logical_axes(
                tensor, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                # While decoding, the bias must cover the full generation length
                # rather than just the current input length.
                if decode and max_decode_length:
                    bias_seqlen = max_decode_length
                else:
                    bias_seqlen = inputs.shape[sequence_dim]
                attn_bias = rel_emb(bias_seqlen, bias_seqlen, False)

        assert inputs.ndim == 3

        # Make name be the exactly same as T5X, since names would affect
        # RNGKey during init and apply. Maybe no need in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = "attention"
        else:
            mha_name = "self_attention"

        inputs = constrain_activation(inputs)

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name,
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)

        def hidden_dropout(x, deterministic):
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        def apply_drop_path(x, deterministic):
            # Stochastic depth on the main path of a residual block. Every call site
            # goes through this helper so that all of them draw from
            # `dropout_rng_name`, the same RNG stream as the other dropout layers.
            if self.drop_path > 0.0:
                drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
                x = nn.Dropout(
                    rate=self.drop_path,
                    broadcast_dims=drop_path_shape,
                    rng_collection=self.dropout_rng_name,
                )(x, deterministic=deterministic)
            return x

        x = constrain_activation(x)
        residual = constrain_activation(residual)

        x = hidden_dropout(x, deterministic)
        x = apply_drop_path(x, deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = constrain_activation(x)

            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name="encoder_decoder_attention",
                window_size=self.window_size,
                softmax_type=self.softmax_type,
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

            y = constrain_activation(y)
            residual = constrain_activation(residual)

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = constrain_activation(mlp_input)

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            activation_params=self.mlp_activation_params,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            transpose_batch_sequence=self.transpose_batch_sequence,
            name="mlp",
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = constrain_activation(z)
        residual = constrain_activation(residual)

        z = hidden_dropout(z, deterministic)
        z = apply_drop_path(z, deterministic)
        z = z + residual

        if self.output_layernorm:
            z = constrain_activation(z)
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        return z