transformer.py 101 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
18
from flax.linen.attention import combine_masks
19
20
21
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
27
28
29
30
31
32
from ..attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    SequenceDescriptor,
)
33
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
34
from ..attention import fused_attn
35
from ..attention import CPStrategy
36
from ..softmax import SoftmaxFusionType
37
38
39
40
41
42
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
43
44
45
46
47

# ---------------------------------------------------------------------------
# Type aliases shared by the modules in this file.
# ---------------------------------------------------------------------------
PRNGKey = Any  # JAX PRNG key; intentionally left untyped
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray

# A precision argument: absent, one value, or a pair of values, each given either
# as a string or as a ``lax.Precision``.
PrecisionLike = Union[
    None,
    str,
    lax.Precision,
    Tuple[str, str],
    Tuple[lax.Precision, lax.Precision],
]

# Parameter initializer signature: (rng, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]

# Flax logical axis rules: pairs of (logical axis name, mesh axis name or None).
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


Phuong Nguyen's avatar
Phuong Nguyen committed
62
# TODO(Phuong): move this function to sharding.py
63
64
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, as a tuple: the given ``rules`` in their
        original order, followed by each TE rule whose logical axis name does not appear
        in ``rules``.

    Raises
    ------
    AssertionError
        If a rule is not a ``(axis_name, mesh_axis_name)`` pair, ``axis_name`` is not a str,
        ``mesh_axis_name`` is neither a str nor None, or the given rules map a TE logical
        axis differently from (or more than once compared to) TE's own rules.
    """
    # Materialize once: ``rules`` is iterated twice below, which would silently drop
    # every user rule from the result if an iterator/generator were passed.
    rules = tuple(rules)

    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for item in get_sharding_map_logic_axis_to_mesh_axis().items():
        key, val = item
        if key in rules_map:
            # A user rule for a TE logical axis must appear exactly once and agree with TE.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


120
121
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention implemented with JAX einsums and the ``Softmax`` module.

    Supports grouped-query attention (query heads a multiple of key/value heads), optional
    sliding-window masking, pre/post-scale bias, dropout on the attention probabilities and the
    softmax variants selected by ``softmax_type``.

    Attributes
    ----------
    attention_dropout: float
        Dropout probability applied to the attention probabilities when not deterministic.
    attn_mask_type: AttnMaskType
        Mask type; selects the ``SoftmaxFusionType`` passed to ``Softmax``.
    attn_bias_type: Optional[AttnBiasType]
        ``POST_SCALE_BIAS`` adds the bias after scaling the logits; ``PRE_SCALE_BIAS`` adds it
        before the (fused) scaling.
    float32_logits: bool
        Cast query and key to float32 before computing the logits.
    scale_factor: Optional[float]
        Logit scale; ``None`` means ``1 / sqrt(head_dim)``.
    transpose_batch_sequence: bool
        Inputs are ``[seqlen, batch, heads, head_dim]`` when True, otherwise
        ``[batch, seqlen, heads, head_dim]``.
    window_size: Optional[Tuple[int, int]]
        Sliding-window size forwarded to ``make_swa_mask``.
    softmax_type: AttnSoftmaxType
        Softmax variant; ``LEARNABLE_SOFTMAX`` creates a ``softmax_offset`` parameter.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    float32_logits: bool = False
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Compute attention of ``query`` over ``key``/``value``.

        Parameters
        ----------
        query: Array
            Query tensor laid out as described by ``transpose_batch_sequence``.
        key, value: Array
            Key/value tensors in the same layout. Their head count may be a divisor of the
            query's head count (grouped-query attention).
        mask: Optional[Array]
            Attention mask in which 1 marks masked-out positions. Ignored for ``NO_MASK`` and
            for ``CAUSAL_MASK`` without a sliding window.
        bias: Optional[Array]
            Attention bias, applied according to ``attn_bias_type``.
        dropout_rng: Optional[PRNGKey]
            RNG key for attention dropout; required when dropout is active.
        deterministic: bool
            When True, dropout is skipped.

        Returns
        -------
        Array
            Attention output in the query's layout. The attention probabilities are cast back to
            the original query dtype before being applied to ``value``.

        Raises
        ------
        ValueError
            If the mask type cannot be mapped to a softmax fusion type (e.g. a padding mask type
            without ``mask``), or if dropout is active but ``dropout_rng`` is None.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        input_dtype = query.dtype

        # Infer number of attention heads from query shape
        # query shape: [..., h, d] where h is num_attention_heads
        num_attention_heads = query.shape[-2]

        # Initialize softmax_offset for learnable softmax.
        # Note: OFF_BY_ONE_SOFTMAX is handled internally by the Softmax module.
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # One learnable offset per query head, shaped (1, h, 1, 1) with head-axis sharding.
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # NOTE(review): presumably deleted so only the resolved local `scale_factor` is used below.
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv), (
                f"The number of query heads ({h_q}) must be a multiple of the number of"
                f" key/value heads ({h_kv})."
            )
            group_size = h_q // h_kv
            # Split the query heads into (h_kv, group_size): query head i uses kv head i // group_size.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        # Name the logits so checkpoint (remat) policies can refer to them.
        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # (b, h, q, k): Last two axes are always replicated
        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                bias = None

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_fusion_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxFusionType.

            Returns the fusion type together with the mask to hand to ``Softmax`` (None when the
            mask is not used, otherwise the sliding-window-adjusted mask).
            """
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxFusionType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                # NOTE(review): with a window_size but no mask, no sliding window is applied here
                # (plain upper-triangular causal masking); confirm callers pass a mask for SWA.
                return SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxFusionType.SCALED, mask
            # Reaching here means mask is None: either the mask type is unknown, or it is a
            # padding-style type that needs a mask which was not provided.
            raise ValueError(
                f"Unsupported {attn_mask_type=} or missing mask (padding mask types require"
                " `mask` to be provided), supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_fusion_type, mask = convert_to_softmax_fusion_type(self.attn_mask_type, mask)

        attn_weights = Softmax(
            softmax_fusion_type=softmax_fusion_type,
            softmax_type=self.softmax_type,
            scale_factor=fused_scale_factor,
        )(attn_weights, mask, bias, softmax_offset=softmax_offset).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            if dropout_rng is None:
                raise ValueError(
                    "dropout_rng must be provided when attention_dropout > 0 and"
                    " deterministic=False."
                )
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Rescale kept entries by 1 / keep_prob (inverted dropout).
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
290
291


292
293
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention dispatched to the ``fused_attn`` (cuDNN fused attention) primitive.

    The inputs are arranged according to ``qkv_layout`` and, when ``transpose_batch_sequence``
    is set, swapped from ``[seqlen, batch, ...]`` to ``[batch, seqlen, ...]`` before the call and
    back afterwards. All remaining attributes are forwarded to ``fused_attn`` unchanged.

    Attributes
    ----------
    attention_dropout: float
        Forwarded as ``dropout_probability``.
    attn_mask_type, attn_bias_type, qkv_layout, softmax_type, window_size, max_segments_per_seq:
        Forwarded to ``fused_attn`` under the same names.
    scale_factor: Optional[float]
        Logit scale forwarded as ``scaling_factor``; ``None`` means ``1 / sqrt(head_dim)``.
    transpose_batch_sequence: bool
        Inputs/outputs use ``[seqlen, batch, ...]`` when True.
    context_parallel_causal_load_balanced, context_parallel_axis, context_parallel_strategy,
    context_checkpoint_name:
        Context-parallel settings forwarded to ``fused_attn``.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Run attention through ``fused_attn``.

        Parameters
        ----------
        query, key, value: Array
            Interpreted according to ``qkv_layout``. For a qkv-packed layout ``query`` holds the
            packed ``[..., 3, h, d]`` tensor and ``key``/``value`` are ignored. For a kv-packed
            layout ``key`` holds the packed ``[..., 2, h, d]`` tensor and ``value`` is ignored.
            Otherwise all three are separate tensors.
        sequence_descriptor: Optional[SequenceDescriptor]
            Forwarded unchanged to ``fused_attn``.
        bias: Optional[Array]
            Attention bias forwarded to ``fused_attn``.
        dropout_rng: Optional[PRNGKey]
            When given, split into ``num_of_devices()`` keys that are passed as the seed.
        deterministic: bool
            Forwarded as ``is_training=not deterministic``.

        Returns
        -------
        Array
            The ``fused_attn`` output, transposed back to ``[seqlen, batch, ...]`` when
            ``transpose_batch_sequence`` is set.

        Raises
        ------
        ValueError
            If ``qkv_layout`` is neither qkv-packed, kv-packed nor separate.
        """
        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # NOTE(review): presumably deleted so only the resolved local `scale_factor` is used below.
        del self.scale_factor

        num_attention_heads = query.shape[-2]
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        # Build the input tuple for fused_attn; only its contents differ between layouts.
        if self.qkv_layout.is_qkvpacked():
            # query is the qkv-packed tensor [..., 3, h, d]; key and value are ignored.
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv_inputs = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            # query is [..., h, d]; key is the kv-packed tensor [..., 2, h, d]; value is ignored.
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv_inputs = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv_inputs = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            qkv_inputs,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            softmax_type=self.softmax_type,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
            context_parallel_strategy=self.context_parallel_strategy,
            context_checkpoint_name=self.context_checkpoint_name,
            softmax_offset=softmax_offset,
        )

        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        assert (
            x.dtype == query.dtype
        ), f"output dtype={x.dtype} does not match input dtype={query.dtype}"
        return x


435
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardwares.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

451
452
453
454
        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
          attention kernel is not available on the system, a warning will be issued, and the module
          will automatically fall back to the unfused backend.
455

456
457
458
459
460
461
462
463
    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

464
465
466
467
468
469
    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
Paweł Gadziński's avatar
Paweł Gadziński committed
470
    num_gqa_groups: int, default = None
471
472
473
474
475
476
477
478
479
480
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
481
482
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
Paweł Gadziński's avatar
Paweł Gadziński committed
483
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
484
485
486

        Each described below:

Paweł Gadziński's avatar
Paweł Gadziński committed
487
        * ``no_mask``: No attention mask is applied. This means the attention will consider the
488
          full sequence without any restrictions.
Paweł Gadziński's avatar
Paweł Gadziński committed
489
490
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
491
          :attr:`__call__` method to specify the padding positions.
Paweł Gadziński's avatar
Paweł Gadziński committed
492
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
493
494
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
Paweł Gadziński's avatar
Paweł Gadziński committed
495
496
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
497

Paweł Gadziński's avatar
Paweł Gadziński committed
498
        |
499

Paweł Gadziński's avatar
Paweł Gadziński committed
500
        .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
501

Paweł Gadziński's avatar
Paweł Gadziński committed
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
        |

        .. note:: THD format only supports ``'padding'`` or ``'causal_padding'`` mask type.

        |

        .. table::
            :widths: auto

            ================== ============ ========== ==============================
            attn_mask_type     mask/sd      SWA        softmax type
            ================== ============ ========== ==============================
            no_mask            None         None       SCALED
            causal             None         None       SCALED_UPPER_TRIANG_MASKED
            causal             None         Yes        SCALED_MASKED
            padding            Required     Yes/No     SCALED_MASKED
            padding_causal     Required     Yes/No     SCALED_MASKED
            ================== ============ ========== ==============================

        where sd stands for sequence_descriptor.
522

523
    attn_bias_type: Optional[str], default = None
524
        Type of the attention bias passed in the attention.
525
526
527
528
529
530
531
532
533
534
535
536
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When left as the default, the type is decided automatically from the MHA's bias parameter:
        :attr:`post_scale_bias` if there is a bias, otherwise :attr:`no_bias`.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
537
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.
538
539
540
541
542
543

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separate tensors with shape = [b, s, h, d].
544
545
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.
546
547
548
549
550
551
552
553
554
555
556
557

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
558
559
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
560
        Indicate whether the input tensors were switched axis of batch
561
        and sequence length dimension. If set to True, the input tensors
562
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
563
564
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
565
566
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
Paweł Gadziński's avatar
Paweł Gadziński committed
567
568
569
570
571
572
573
574
    context_parallel_causal_load_balanced: bool
        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
    context_parallel_axis: str
        The name of the context parallel axis.
    context_parallel_strategy: CPStrategy
        The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
    context_checkpoint_name: str
        The name of the context checkpoint in the forward pass of fused attention.
575
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Paweł Gadziński's avatar
Paweł Gadziński committed
576
        Softmax type as described in the paper
577
578
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
601
    """
602

603
604
605
    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
606
607
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
608
    attn_bias_type: AttnBiasType = None
609
    dropout_rng_name: str = "dropout"
610
    float32_logits: bool = False
611
    qkv_layout: str = "bshd_bshd_bshd"
612
    scale_factor: Optional[float] = None
613
    transpose_batch_sequence: bool | None = None
614
    window_size: Optional[Tuple[int, int]] = None
615
    max_segments_per_seq: Optional[int] = 1
616
617
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
618
    context_parallel_strategy: str = "DEFAULT"
619
    context_checkpoint_name: str = "context"
620
    softmax_type: str = "vanilla"
621

622
623
624
625
626
627
628
629
630
631
632
    def __post_init__(self):
        """Resolve an unset ``transpose_batch_sequence`` before delegating to the parent.

        A value of ``None`` means the caller relied on the default. In that case a warning about
        the changed default is emitted and the flag is pinned to False. The parent
        ``__post_init__`` is always called afterwards.
        """
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        relies_on_default = self.transpose_batch_sequence is None
        if relies_on_default:
            warnings.warn(
                "transpose_batch_sequence defaults to False in DotProductAttention starting"
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False
        super().__post_init__()

633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
    def _assert_dtypes(self, query: Array, key: Array, value: Array, qkv_layout: QKVLayout):
        """Check that the separately passed attention inputs share the query's dtype.

        - qkv-packed layout: nothing is compared (key/value live inside the packed query).
        - kv-packed layout: the packed key tensor is compared against the query.
        - separate layout: key and then value are compared against the query.

        Raises
        ------
        AssertionError
            If a compared tensor's dtype differs from ``query.dtype``.
        ValueError
            If ``qkv_layout`` is none of the three layout families above.
        """
        if qkv_layout.is_qkvpacked():
            return
        if qkv_layout.is_kvpacked():
            assert (
                key.dtype == query.dtype
            ), f"Expected kv dtype={key.dtype} to match query dtype={query.dtype}."
            return
        if not qkv_layout.is_separate():
            raise ValueError(f"Unsupported {qkv_layout=}.")
        for tensor_name, tensor in (("key", key), ("value", value)):
            assert (
                tensor.dtype == query.dtype
            ), f"Expected {tensor_name} dtype={tensor.dtype} to match query dtype={query.dtype}."

651
    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Compute dot-product attention.

        The fused backend is used when the ``NVTE_FUSED_ATTN`` environment variable is
        non-zero (default ``"1"``) and ``is_fused_attn_kernel_available`` reports a kernel
        for this configuration. If fused attention is enabled but no kernel is available,
        a warning describing the configuration is emitted and the unfused implementation
        is used instead.

        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        sequence_descriptor: SequenceDescriptor or jax.numpy.ndarray, default = None
            Sequence / masking information for the attention. A plain array is treated
            like :attr:`mask`. The unfused backend only accepts ``None`` or an array.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input. It is cast to the dtype
            of :attr:`query`.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: jax.numpy.ndarray, default = None
            Deprecated alias of :attr:`sequence_descriptor`; passing both raises ``ValueError``.
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors, with the same dtype as :attr:`query`.

        Raises
        ------
        ValueError
            If both ``sequence_descriptor`` and ``mask`` are given, or if
            :attr:`context_parallel_strategy` is not a recognised strategy.
        """
        input_dtype = query.dtype

        # `mask` is the deprecated spelling of `sequence_descriptor`.
        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None
            bias = bias.astype(input_dtype)

        self._assert_dtypes(query, key, value, qkv_layout)

        # Use fused attn (if kernel check below passes) by default
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout.is_qkvpacked():
            # K and V are packed inside `query` (whatever `key` holds is unused),
            # so they share its sequence length for every QKV-packed layout.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        if qkv_layout.is_separate():
            head_dim_qk = query.shape[-1]
            head_dim_v = value.shape[-1]
        else:
            head_dim_qk = self.head_dim
            head_dim_v = self.head_dim

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            input_dtype,
            # self._assert_dtypes enforces matching Q, K, V dtypes and bias was cast to
            # input_dtype above, so input_dtype is also the kv dtype.
            input_dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            head_dim_qk,
            head_dim_v,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            # Report every input of the kernel check so the fallback can be diagnosed.
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n{softmax_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
                f"{self.window_size=}\n"
            )

        # Only draw a dropout RNG when dropout can actually be applied.
        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(head_dim_qk)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # case-insensitive mapping for context parallel strategy
        cp_strategy_map = {
            "DEFAULT": CPStrategy.DEFAULT,
            "ALL_GATHER": CPStrategy.ALL_GATHER,
            "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
            "RING": CPStrategy.RING,
        }

        strategy_key = self.context_parallel_strategy.upper()
        if strategy_key in cp_strategy_map:
            context_parallel_strategy = cp_strategy_map[strategy_key]
        else:
            valid_strategies = list(cp_strategy_map.keys())
            raise ValueError(
                f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
                f"Valid options are: {valid_strategies} (case insensitive)"
            )

        if not use_fused_attn:
            # unfused attention only supports split query, key, value
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x
854
855


856
857
858
859
860
861
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding

    x should be of shape
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor laid out as above. Coordinates are rotated in pairs along the
        last (hidden) axis, so its size is expected to be even.
    windows: Tuple[int, int]
        ``(min_window, max_window)`` time-scales. Rotation pair ``i`` uses the
        time-scale ``min_window * (max_window / min_window) ** (2 * i / hidden)``.
    transpose_batch_sequence: bool
        Selects which of the two layouts above ``x`` uses.
    group_method: str, default = 'consecutive'
        ``'consecutive'`` pairs hidden index ``2k`` with ``2k + 1``; ``'alternate'``
        pairs index ``i`` with ``i + hidden // 2``. Matching ignores case, surrounding
        whitespace, ``'-'`` and ``'_'``.

    Returns
    -------
    jax.numpy.ndarray
        The rotated tensor, with the same shape and dtype as ``x``.

    Raises
    ------
    AssertionError
        If ``group_method`` is not a recognised method.
    """

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid relative positional embedding group method. "
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."
        )
        return canonicalized_gm

    # Validate cheaply before building any arrays.
    group_method = canonicalize_group_method(group_method)

    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2

    min_window = windows[0]
    max_window = windows[1]

    # One time-scale per rotated pair, growing geometrically from min_window.
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Prepend singleton axes so time_scales broadcasts against x: [1, ..., 1, half].
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Token positions vary only along the sequence axis; every other axis is 1.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        # Rotate (x[i], x[i + half]) pairs: first half vs. second half of the hidden axis.
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
        return output

    def consecutive_impl():
        # Rotate (x[2k], x[2k + 1]) pairs; each pair shares one time-scale.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        # Even index 2k takes its partner from 2k + 1 (shift left);
        # odd index 2k + 1 takes its partner from 2k (shift right).
        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 on even indices, +1 on odd indices, giving the 2-D rotation.
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()
933
934


935
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    Flags recording which parts of a layer low-rank adaptation (LoRA) is applied
    to: the QKV projection, the output projection, and the MLP.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Any object exposing the three flags compares by value; anything else
        # defers to Python's default handling instead of raising AttributeError.
        try:
            other_flags = (other.qkv_proj, other.output_proj, other.mlp)
        except AttributeError:
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == other_flags

    # Defining __eq__ already makes instances unhashable; spelled out because the
    # flags are mutable and must not be used as dict/set keys.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj!r}, "
            f"output_proj={self.output_proj!r}, mlp={self.mlp!r})"
        )
949
950
951
952


def _canonicalize_lora_scope(scope):
    """Translate a LoRA scope name into a :class:`LoRAScope`.

    ``scope`` is matched case-insensitively and ``None`` is treated as ``'none'``.
    Accepted values:

    * ``'none'``: LoRA is applied nowhere.
    * ``'all'``: QKV projection, output projection and MLP.
    * ``'qkv_proj'`` / ``'output_proj'`` / ``'mlp'``: only the named part.
    * ``'exclude_qkv_proj'`` / ``'exclude_output_proj'`` / ``'exclude_mlp'``:
      every part except the named one.

    Raises:
        AssertionError: if ``scope`` is not one of the accepted values.
    """
    # (qkv_proj, output_proj, mlp) flags for every accepted scope name.
    scope_flags = {
        "none": (False, False, False),
        "all": (True, True, True),
        "qkv_proj": (True, False, False),
        "output_proj": (False, True, False),
        "mlp": (False, False, True),
        "exclude_qkv_proj": (False, True, True),
        "exclude_output_proj": (True, False, True),
        "exclude_mlp": (True, True, False),
    }

    canonical_scope = "none" if scope is None else scope.lower()

    assert canonical_scope in scope_flags, (
        f"Invalid low rank adaptation scope: {scope!r}. "
        f"Expected one of {list(scope_flags)} (case insensitive)."
    )

    qkv_proj, output_proj, mlp = scope_flags[canonical_scope]
    return LoRAScope(qkv_proj=qkv_proj, output_proj=output_proj, mlp=mlp)


991
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
992
993
994
995
996
997
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
998
    head_dim: int
999
        The hidden dimension of each attention head.
1000
1001
    num_attention_heads: int
        The number of attention heads.
Paweł Gadziński's avatar
Paweł Gadziński committed
1002
    num_gqa_groups: int, default = None
1003
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
zlsh80826's avatar
zlsh80826 committed
1004
1005
1006
1007
1008
1009
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1010
1011
1012
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
1013
1014
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
Paweł Gadziński's avatar
Paweł Gadziński committed
1015
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
1016
1017
1018

        Each described below:

Paweł Gadziński's avatar
Paweł Gadziński committed
1019
        * ``no_mask``: No attention mask is applied. This means the attention will consider the
1020
          full sequence without any restrictions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1021
1022
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
1023
          :attr:`__call__` method to specify the padding positions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1024
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
1025
1026
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
Paweł Gadziński's avatar
Paweł Gadziński committed
1027
1028
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
1029

Paweł Gadziński's avatar
Paweł Gadziński committed
1030
        .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
1031

1032
1033
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
Paweł Gadziński's avatar
Paweł Gadziński committed
1034
        Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
1035
        When default is present, the type is automatically decided by the MHA's bias parameter.
Paweł Gadziński's avatar
Paweł Gadziński committed
1036
        Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used.
1037
    dropout_rng_name: str, default = 'dropout'
1038
1039
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
1040
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1041
        Indicate the type of layer normalization.
1042
    layernorm_epsilon: float, default = 1e-6
1043
        A value added to the denominator of layer normalization for numerical stability.
1044
    zero_centered_gamma: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1045
        If set to ``True``, the LayerNorm formula changes to
1046
1047

        .. math::
Paweł Gadziński's avatar
Paweł Gadziński committed
1048
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
1049
1050
            (1 + \gamma) + \beta

Paweł Gadziński's avatar
Paweł Gadziński committed
1051
        This parameter is only applicable for ``'layernorm'``.
1052
    kernel_init: Initializer, default =
Paweł Gadziński's avatar
Paweł Gadziński committed
1053
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
1054
        Used for initializing the QKV and output projection weights.
Paweł Gadziński's avatar
Paweł Gadziński committed
1055
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1056
    use_bias: bool, default = False
1057
        Indicate whether or not to enable bias shifting for QKV and output projections.
Paweł Gadziński's avatar
Paweł Gadziński committed
1058
1059
        If set to ``False``, the layer will not learn additive biases.
    bias_init: Initializer, default = ``flax.linen.initializers.zeros``
1060
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
Paweł Gadziński's avatar
Paweł Gadziński committed
1061
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1062
    input_layernorm: bool, default = True
Paweł Gadziński's avatar
Paweł Gadziński committed
1063
        If set to ``False``, layer normalization to the input is not applied.
1064
    return_layernorm_output: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1065
        If set to ``True``, output of layernorm is returned from the forward together with the output
1066
1067
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
1068
1069
1070
1071
1072
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1073
1074
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
        ``['consecutive', 'alternate']``. ``'alternate'`` pairs index :math:`i` with
        :math:`i + d/2`, where :math:`d` is the hidden dimension. ``'consecutive'`` pairs
        index :math:`2i` with :math:`2i + 1`.
1077
1078
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
Paweł Gadziński's avatar
Paweł Gadziński committed
1079
        ``['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']``
1080
1081
1082
1083
1084
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` enables LoRA for at least one projection.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
Paweł Gadziński's avatar
Paweł Gadziński committed
1085
        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
1086
1087
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1088
1089
1090
1091
1092
1093
1094
1095
    num_heads: int, default = None
        Deprecated. Please refer to `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer to `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer to `input_layernorm`.
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer to `return_layernorm_output`.
1096
1097
1098

    Optimization parameters
    -----------------------
1099
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1100
        The data type used to allocate the initial parameters.
1101
    fuse_qkv_params: bool, default = True
1102
        If set to True, this module exposes a single fused
1103
1104
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1105
1106
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
1107
        Indicate whether the input tensors were switched axis of batch
1108
1109
1110
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
1111
        Indicate whether to scale attention logits.
Paweł Gadziński's avatar
Paweł Gadziński committed
1112
1113
        If set to True, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
        else :math:`Q \cdot K^T`
1114
1115
1116
1117
1118
1119
1120
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer to `fuse_qkv_params`.
1121
1122
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1123
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Paweł Gadziński's avatar
Paweł Gadziński committed
1124
        Softmax type as described in the paper
1125
1126
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
1149
1150
1151
    """

    head_dim: int
1152
1153
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
1154
1155
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
1156
    input_layernorm: bool = True
1157
1158
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
1159
    return_layernorm_output: bool = False
1160
    zero_centered_gamma: bool = False
1161
1162
1163
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
1164
    attn_mask_type: str = "causal"
1165
    attn_bias_type: Optional[str] = None
1166
1167
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1168
1169
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1170
1171
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1172
    dtype: DType = jnp.float32
1173
    fuse_qkv_params: bool = True
1174
    transpose_batch_sequence: bool | None = None
1175
    enable_sequence_parallel: bool = False
1176
1177
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1178
    float32_logits: bool = False
1179
    window_size: Optional[Tuple[int, int]] = None
1180
    softmax_type: str = "vanilla"
1181
1182
1183
1184
1185
1186
1187

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
1188
1189

    def __post_init__(self):
        """Resolve changed defaults and deprecated aliases, then run the base hook.

        * ``transpose_batch_sequence=None`` emits a warning and becomes ``False``.
        * ``num_heads`` / ``dropout_rate`` are copied onto ``num_attention_heads`` /
          ``attention_dropout`` and emit a ``DeprecationWarning``.
        * ``apply_residual_connection_post_layernorm`` and ``fuse_qkv`` only emit a
          ``DeprecationWarning``; their values are not copied here.
        * ``output_layernorm`` must be left unset.
        * ``kernel_init`` defaults to ``variance_scaling(1.0, "fan_in", "normal")``
          and ``num_gqa_groups`` defaults to ``num_attention_heads``.

        Raises:
            AssertionError: if ``output_layernorm`` is set.
        """
        # Deal with changed defaults in API
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        # None implies that the user is relying on defaults, hence warn the user and set the new defaults
        if self.transpose_batch_sequence is None:
            warnings.warn(
                "transpose_batch_sequence defaults to False in MultiHeadAttention starting"
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False

        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        # NOTE(review): unlike num_heads / dropout_rate above, the next two deprecated
        # flags are not forwarded to return_layernorm_output / fuse_qkv_params here;
        # confirm whether __call__ still reads them before relying on the alias.
        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        # Evaluated after the num_heads alias so the default tracks the final head count.
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1250
1251
1252
1253
1254
1255
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1256
        inputs_q: jax.numpy.ndarray
1257
            Input tensor for query projection.
1258
        inputs_kv: jax.numpy.ndarray
1259
            Input tensor for key/value projection.
1260
1261
1262
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
Paweł Gadziński's avatar
Paweł Gadziński committed
1263
            Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
1264
1265
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1266
        *
1267
        decode: bool, default = False
1268
            Indicate whether to prepare and use an autoregressive cache.
1269
        deterministic: bool, default = False
1270
1271
1272
1273
            Disable dropout layers if set to True.

        Returns
        -------
1274
        outputs: jax.numpy.ndarray
1275
1276
            Output tensors.
        """
1277

1278
1279
1280
1281
1282
        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

1283
        def query_init(*args):
1284
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

1327
1328
1329
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1330

1331
1332
1333
1334
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1335
1336
1337
1338
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1339
1340
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1341
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1342
            if is_qkvpack:
1343
                qkv_proj, ln_out = LayerNormDenseGeneral(
1344
                    enable_layernorm=self.input_layernorm,
1345
                    layernorm_type=self.layernorm_type,
1346
                    zero_centered_gamma=self.zero_centered_gamma,
1347
1348
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1349
1350
                    features=(3, self.num_attention_heads * self.head_dim),
                    return_layernorm_output=self.return_layernorm_output,
1351
1352
1353
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1354
1355
1356
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1357
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1358
1359
1360
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1361
1362
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1363
                    transpose_batch_sequence=self.transpose_batch_sequence,
1364
1365
1366
1367
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1368
                qkv_layout = QKVLayout.BS3HD
1369
1370
            else:
                query, ln_out = LayerNormDenseGeneral(
1371
                    enable_layernorm=self.input_layernorm,
1372
                    layernorm_type=self.layernorm_type,
1373
                    zero_centered_gamma=self.zero_centered_gamma,
1374
1375
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1376
1377
                    features=self.num_attention_heads * self.head_dim,
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1378
1379
1380
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1381
1382
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1383
                    bias_axes=(W_TP_AXES,),
1384
1385
1386
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1387
1388
                    dtype=self.dtype,
                    kernel_init=query_init,
1389
1390
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1391
                    transpose_batch_sequence=self.transpose_batch_sequence,
1392
1393
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1394
1395
1396
1397
1398

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1410
                    transpose_batch_sequence=self.transpose_batch_sequence,
1411
1412
1413
1414
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1415
                qkv_layout = QKVLayout.BSHD_BS2HD
1416
1417
1418
1419
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1420
                features=self.num_gqa_groups * self.head_dim,
1421
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1422
1423
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1424
                bias_axes=(W_TP_AXES,),
1425
1426
1427
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1428
1429
                dtype=self.dtype,
            )
1430
            query, ln_out = LayerNormDenseGeneral(
1431
                enable_layernorm=self.input_layernorm,
1432
                layernorm_type=self.layernorm_type,
1433
                zero_centered_gamma=self.zero_centered_gamma,
1434
1435
                epsilon=self.layernorm_epsilon,
                axis=-1,
1436
                features=self.num_attention_heads * self.head_dim,
1437
                return_layernorm_output=True,
1438
1439
1440
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1441
1442
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1443
                bias_axes=(W_TP_AXES,),
1444
1445
1446
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1447
1448
                dtype=self.dtype,
                kernel_init=query_init,
1449
1450
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1451
                transpose_batch_sequence=self.transpose_batch_sequence,
1452
1453
                name="query",
            )(inputs_q)
1454

1455
            if is_self_attn:
1456
1457
1458
                assert ln_out is not None
                inputs_kv = ln_out

1459
            query = query.astype(input_dtype)
1460
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1461
            key = key.astype(input_dtype)
1462
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
1463
            value = value.astype(input_dtype)
1464
1465
1466
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1467
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1468

1469
        if self.enable_rotary_pos_emb:
1470
1471
1472
1473
1474
1475
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1476

1477
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1478
1479
1480
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1493
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1494

1495
1496
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1497
1498
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1499
1500

        if decode:
1501
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1502
1503
1504
1505
1506
1507
1508
1509
1510
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1511
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1512
                if self.transpose_batch_sequence:
1513
1514
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1515
1516
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1517
1518
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1519
                    one_hot_indices_shape = (1, length, 1, 1)
1520
1521
1522
1523

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1524
1525
1526
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1527

1528
                cur_index = cache_index.value.astype(jnp.int32)
1529
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1530
1531
1532
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1533
1534
1535
1536
1537
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1538
1539
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1540
1541

                if bias is not None:
1542
1543
1544
1545
1546
1547
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1548

1549
1550
1551
1552
1553
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1554
1555
1556
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1568
        else:
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1579
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1592
            window_size=self.window_size,
1593
            softmax_type=self.softmax_type,
1594
        )(*dpa_args, mask, bias, deterministic=deterministic)
1595
1596
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1597
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1598
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1599

1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1615

1616
1617
1618
        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
1619
        return out, ln_out
1620
1621


1622
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Every (query position, key position) pair is assigned to one of :attr:`num_buckets`
    buckets according to the key's position relative to the query. Each bucket owns one
    learnable scalar per attention head, and the looked-up scalars are returned as an
    attention bias tensor.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings. When ``True``, half of the buckets are reserved for keys that
            lie after the query and the other half for keys at or before it. When
            ``False``, every key that lies after the query falls into bucket 0
            (the same bucket as the query's own position).

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`
            and dtype :attr:`dtype`.
        """
        # Bucket ids are computed eagerly with NumPy, so q_seqlen and k_seqlen must be
        # concrete integers rather than traced JAX values.
        context_position = np.arange(q_seqlen, dtype=np.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=np.int32)[None, :]
        # relative_position[i, j] = j - i, i.e. positive when key j lies after query i.
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in half: keys after the query are shifted into the
            # upper half, then only the magnitude of the distance is used.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Keys after the query are clamped to distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Distances below rpb_max_exact each get their own bucket; larger distances are
        # spaced logarithmically up to max_distance and clamped to the last bucket.
        # NOTE(review): rpb_max_exact is 0 when num_buckets < 2 (or < 4 with
        # bidirectional=True), which makes the divisions below ill-defined.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            # eps keeps log() finite at distance 0; those entries take the "small" branch.
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias: one learnable scalar per (head, bucket).
        relative_attention_bias = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # One-hot encode the bucket ids -> (num_buckets, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        # Contract over the bucket axis:
        # (heads, buckets) x (buckets, q_seqlen, k_seqlen) -> (heads, q_seqlen, k_seqlen).
        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        # Leading singleton axis -> (1, heads, q_seqlen, k_seqlen).
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer.

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer (value ``'encoder'``).
    DECODER:
        Decoder type of TransformerLayer (value ``'decoder'``); the TransformerLayer
        adds a cross-attention block after self-attention for this type.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


1734
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
1735
1736
1737
1738
1739
1740
1741
1742
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
1743
        The hidden size of each input sample.
1744
    mlp_hidden_size: int, default = 2048
1745
        Intermediate size to which input samples are projected.
1746
    num_attention_heads: int, default = 8
1747
        Number of attention heads in the transformer layer.
Paweł Gadziński's avatar
Paweł Gadziński committed
1748
    num_gqa_groups: int, default = None
zlsh80826's avatar
zlsh80826 committed
1749
1750
1751
1752
1753
1754
1755
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the querys.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1756
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1757
        Indicate the type of layer normalization.
1758
    layernorm_epsilon: float, default = 1e-6
1759
        A value added to the denominator of layer normalization for numerical stability.
1760
    zero_centered_gamma: bool, default = False
1761
1762
1763
1764
1765
1766
1767
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1768
    hidden_dropout: float, default = 0.1
1769
        Dropout probability for the dropout op after FC2 layer.
1770
    hidden_dropout_dims: Sequence[int], default = ()
1771
        Dimensions that will share the same dropout mask for hidden
1772
    attention_dropout: float, default = 0.1
1773
        Dropout probability for the dropout op during multi-head attention.
1774
    intermediate_dropout: float, default = 0.0
1775
1776
1777
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
1778
    dropout_rng_name: str, default = 'dropout'
1779
1780
        The key in given RNGs via flax.linen.Module.apply that for
        generating Dropout masks in the Multi-Head Attention.
1781
    mha_kernel_init: Initializer, default =
Paweł Gadziński's avatar
Paweł Gadziński committed
1782
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
1783
        Used for initializing weights of QKV and Output projection weights.
Paweł Gadziński's avatar
Paweł Gadziński committed
1784
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1785
    mlp_kernel_init: Initializer, default =
Paweł Gadziński's avatar
Paweł Gadziński committed
1786
        ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')``
1787
        Used for initializing weights of FC1 and FC2 layers.
Paweł Gadziński's avatar
Paweł Gadziński committed
1788
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1789
    mlp_activations: Sequence[str], default = ('gelu', )
1790
        The sequence of activation functions to apply after the first linear transformation.
1791
        Each activation has its own transformation layer.
1792
    mlp_activation_params: dict = None
Paweł Gadziński's avatar
Paweł Gadziński committed
1793
1794
         This is only used when ``('clamped_silu', 'clamped_linear')`` is in :attr:`mlp_activations`. At the moment
        ``ClampedSwiglu`` is the only activation that requires parameters.
1795
    use_bias: bool, default = False
1796
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
Paweł Gadziński's avatar
Paweł Gadziński committed
1797
1798
        If set to ``False``, the layer will not learn additive biases.
    bias_init: Initializer, default = ``flax.linen.initializers.zeros``
1799
1800
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
Paweł Gadziński's avatar
Paweł Gadziński committed
1801
        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
1802
    apply_residual_connection_post_layernorm: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1803
        If set to ``True``, residual connections are taken from the output
1804
1805
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
Paweł Gadziński's avatar
Paweł Gadziński committed
1806
        If set to ``True``, layer normalization is applied on the output side,
1807
1808
1809
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
1810
1811
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
1812
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
1813
        If set to TransformerLayerType.DECODER, an additional cross-attention block
Paweł Gadziński's avatar
Paweł Gadziński committed
1814
        is added after self-attention.this can be used for structures like T5
1815
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
1816
    self_attn_mask_type: str, default = 'causal'
1817
1818
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
Paweł Gadziński's avatar
Paweł Gadziński committed
1819
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
1820
1821
1822

        Each described below:

Paweł Gadziński's avatar
Paweł Gadziński committed
1823
        * ``no_mask``: No attention mask is applied. This means the self attention will consider the
1824
          full sequence without any restrictions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1825
1826
        * ``padding``: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
1827
          :attr:`__call__` method to specify the padding positions.
Paweł Gadziński's avatar
Paweł Gadziński committed
1828
        * ``causal``: An upper triangular mask is applied to the softmax inputs,
1829
1830
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
Paweł Gadziński's avatar
Paweł Gadziński committed
1831
1832
        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
1833

Paweł Gadziński's avatar
Paweł Gadziński committed
1834
        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
1835

1836
1837
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
Paweł Gadziński's avatar
Paweł Gadziński committed
1838
        Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
1839
        When default is present, the type is automatically decided by the MHA's bias parameter.
Paweł Gadziński's avatar
Paweł Gadziński committed
1840
        Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used.
1841
    enable_relative_embedding: bool, default = True
1842
        Whether to enable relative embedding as shifting of attention logits.
1843
    relative_embedding: flax.linen.Module, default = None
1844
        The module for relative embedding execution, only used when
Paweł Gadziński's avatar
Paweł Gadziński committed
1845
        :attr:`enable_relative_embedding=True`. Default is ``None``, which will create
1846
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
Paweł Gadziński's avatar
Paweł Gadziński committed
1847
        Default: ``RelativePositionBiases( num_buckets=32, max_distance=128,
1848
1849
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
Paweł Gadziński's avatar
Paweł Gadziński committed
1850
        name='relpos_bias')``
1851
1852
1853
1854
1855
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1856
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ``['consecutive', 'alternate']``. ``'alternate'`` pairs index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with
        :math:`i + 1`.
1861
1862
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
Paweł Gadziński's avatar
Paweł Gadziński committed
1863
1864
        ``['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']``
1865
1866
1867
1868
1869
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation. Only takes effect when low rank adaptation
        is enabled through :attr:`low_rank_adaptation_scope` (i.e. it is not ``'none'``).
    low_rank_adaptation_alpha: float, default = None
        The alpha used to compute the scaling factor applied to the LoRA output:
        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
1871
1872
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1873
1874
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1875
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Paweł Gadziński's avatar
Paweł Gadziński committed
1876
        Softmax type as described in the paper
1877
1878
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900

        For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

        * ``'vanilla'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

        * ``'off-by-one'``:

          .. math::
             Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

        * ``'learnable'``:

          .. math::
             Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

          where :math:`\alpha` is a learnable parameter of shape ``[h]``.

        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
        (``'zero sink'`` and ``'learnable sink'``).
1901
        Only supported for fused attention backend.
1902
1903
1904

    Optimization parameters
    -----------------------
1905
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
1907
    drop_path: float, default = 0.0
1908
        When > 0.0, applies stochastic depth per sample in the main
1909
1910
        path of the residual block.
    fuse_qkv_params: bool, default = True
Paweł Gadziński's avatar
Paweł Gadziński committed
1911
        If set to ``True``, ``TransformerLayer`` module exposes a single fused
1912
1913
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1914
    transpose_batch_sequence: bool, default = False
        Indicate whether the batch and sequence-length axes of the input tensors are
        swapped. If set to ``True``, the input tensors should be in ``(seqlen, batch, hidden)``,
        otherwise ``(batch, seqlen, hidden)``.
1918
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to ``True``, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
        else :math:`Q \cdot K^T`.
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
1924
1925
1926
1927
1928
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
1929
    num_gqa_groups: Optional[int] = None
1930
    layernorm_type: str = "layernorm"
1931
    layernorm_epsilon: float = 1e-6
1932
    zero_centered_gamma: bool = False
1933
1934
1935
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
1936
    intermediate_dropout: float = 0.0
1937
    intermediate_dropout_dims: Sequence[int] = ()
1938
    dropout_rng_name: str = "dropout"
1939
1940
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
1941
    mlp_activations: Sequence[str] = ("gelu",)
1942
    mlp_activation_params: dict = None
1943
1944
1945
1946
1947
1948
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
1949
    self_attn_mask_type: str = "causal"
1950
    self_attn_bias_type: Optional[str] = None
1951
1952
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
1953
1954
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1955
1956
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1957
1958
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1959
1960
1961
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
1962
    transpose_batch_sequence: bool = False
1963
    enable_sequence_parallel: bool = False
1964
1965
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1966
    window_size: Optional[Tuple[int, int]] = None
1967
    softmax_type: str = "vanilla"
1968
1969
1970

    def __post_init__(self):
        """Resolve field defaults that depend on other fields, then finish flax init.

        * ``mha_kernel_init`` falls back to a unit-scale, fan-in variance-scaling
          initializer sampling from a normal distribution.
        * ``mlp_kernel_init`` falls back to the same initializer family, sampling
          from a truncated normal distribution instead.
        * ``num_gqa_groups`` falls back to ``num_attention_heads``.

        Both fallback initializers are created with ``self.dtype``.
        """
        # (field name, variance_scaling distribution) for each kernel initializer
        # that is filled in only when the user left it unset.
        fallback_distributions = (
            ("mha_kernel_init", "normal"),
            ("mlp_kernel_init", "truncated_normal"),
        )
        for attr_name, distribution in fallback_distributions:
            if getattr(self, attr_name) is not None:
                continue
            setattr(
                self,
                attr_name,
                nn.initializers.variance_scaling(1.0, "fan_in", distribution, dtype=self.dtype),
            )

        if self.num_gqa_groups is None:
            # Default: one GQA group per attention head.
            self.num_gqa_groups = self.num_attention_heads

        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        encoded: Optional[Array] = None,
        attention_mask: Optional[Array] = None,
        encoder_decoder_mask: Optional[Array] = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,
    ):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor of rank 3, laid out as ``(batch, seqlen, hidden)``, or as
            ``(seqlen, batch, hidden)`` when :attr:`transpose_batch_sequence=True`.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`. Required in that case.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
            :attr:`True` means mask out the corresponding values.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA). Must be ``False`` for encoder layers.
        max_decode_length: int, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER`,
            :attr:`enable_relative_embedding=True` and ``decode=True``.
            When ``None`` (or 0), the input sequence length is used instead.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors, with the same dtype as ``inputs``.
        """
        input_dtype = inputs.dtype

        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )

        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."

        # Check the rank before any dimension of ``inputs`` is read below.
        assert inputs.ndim == 3, f"inputs should be a rank-3 tensor, but got {inputs.ndim=}."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            # Logical sharding axes for the leading (batch, seqlen) pair of an
            # activation, ordered according to ``transpose_batch_sequence``. The
            # sequence axis uses SEQLEN_TP_AXES when it is shared (by default:
            # when sequence parallelism is enabled).
            axes = [None, None]

            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        def shard_hidden(tensor):
            # Constrain a rank-3 activation to the default logical layout used
            # between sub-blocks: (batch/seqlen axes..., HIDDEN_AXES).
            return with_sharding_constraint_by_logical_axes(
                tensor, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding

            # The third positional argument is True for the encoder and False for
            # the decoder (presumably a bidirectional flag — see RelativePositionBiases).
            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    bias_len = max_decode_length
                else:
                    bias_len = inputs.shape[sequence_dim]
                attn_bias = rel_emb(bias_len, bias_len, False)

        # Make the names exactly the same as T5X, since names affect the RNGKey
        # during init and apply. This may become unnecessary in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = "attention"
        else:
            mha_name = "self_attention"

        inputs = shard_hidden(inputs)

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name,
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)

        def hidden_dropout(x, deterministic):
            # Dropout on a sub-block output. ``hidden_dropout_dims`` are passed as
            # broadcast dims, so each entry must be a valid axis index for ``x``.
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        def apply_drop_path(x, deterministic):
            # Per-sample stochastic depth on a residual branch (no-op when
            # ``drop_path`` is 0). Both call sites share this helper so they use
            # the same configured RNG collection; the MLP branch previously fell
            # back to flax's default "dropout" collection.
            if self.drop_path > 0.0:
                drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
                x = nn.Dropout(
                    rate=self.drop_path,
                    broadcast_dims=drop_path_shape,
                    rng_collection=self.dropout_rng_name,
                )(x, deterministic=deterministic)
            return x

        x = shard_hidden(x)
        residual = shard_hidden(residual)

        x = hidden_dropout(x, deterministic)
        x = apply_drop_path(x, deterministic)

        # In post-LayerNorm residual mode, the residual is the LayerNorm output
        # returned by the attention block rather than its input.
        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = shard_hidden(x)

            # Cross-attention over the encoder output. Stochastic depth is not
            # applied on this branch.
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name="encoder_decoder_attention",
                window_size=self.window_size,
                softmax_type=self.softmax_type,
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

            y = shard_hidden(y)
            residual = shard_hidden(residual)

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = shard_hidden(mlp_input)

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            activation_params=self.mlp_activation_params,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            transpose_batch_sequence=self.transpose_batch_sequence,
            name="mlp",
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = shard_hidden(z)
        residual = shard_hidden(residual)

        z = hidden_dropout(z, deterministic)
        z = apply_drop_path(z, deterministic)
        z = z + residual

        if self.output_layernorm:
            z = shard_hidden(z)
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        return z