transformer.py 99 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
18
from flax.linen.attention import combine_masks
19
20
21
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
27
28
29
30
31
32
from ..attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    SequenceDescriptor,
)
33
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
34
from ..attention import fused_attn
35
from ..attention import CPStrategy
36
from ..softmax import SoftmaxFusionType
37
38
39
40
41
42
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
43
44
45
46
47

# Type aliases used in the annotations throughout this module.
PRNGKey = Any  # PRNG key type, annotated loosely as Any.
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# A matmul precision setting: None, a string, a lax.Precision, or a 2-tuple of either form.
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
# Parameter initializer signature: (rng, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical axis rules: a sequence of (logical_axis_name, mesh_axis_name or None) pairs.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


# TODO(Phuong): move this function to sharding.py
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    The given rules are kept first, in their original order. Each TE rule (from
    ``get_sharding_map_logic_axis_to_mesh_axis``) is appended only when its axis name does
    not already appear in ``rules``. When an axis name appears in both, the given rules must
    map it exactly once, to the same mesh axis as TE; otherwise an ``AssertionError`` is raised.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, as a tuple.

    Raises
    ------
    AssertionError
        If a rule is malformed, or if a given rule conflicts with TE's rule for the same axis.
    """
    # Materialize once: the rules are iterated twice below (validation, then copying),
    # and a one-shot iterable would otherwise be empty on the second pass.
    rules = tuple(rules)

    # axis_name -> every mesh axis the given rules map it to (duplicates are kept so that
    # conflicting user rules can be detected below).
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item[0], item[1]
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for item in get_sharding_map_logic_axis_to_mesh_axis().items():
        key, val = item
        if key in rules_map:
            # The user already mapped this axis; it must agree with TE's mapping exactly.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


120
121
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention built from native JAX operations.

    ``query`` is contracted with ``key`` via ``jnp.einsum`` to form the logits. The ``Softmax``
    module normalizes them, applying the mask and bias and, where possible, the scale.
    Attention dropout is optionally applied, and the weights are contracted with ``value``.
    Grouped-query attention is used when the query has more heads than the key; the query
    head count must be a multiple of the key head count.
    """

    # Probability of zeroing an attention weight after the softmax (only when not deterministic).
    attention_dropout: float = 0.0
    # Together with ``mask`` and ``window_size``, selects the SoftmaxFusionType.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    # PRE_SCALE_BIAS: bias is added to the raw logits and the scale is fused into Softmax.
    # POST_SCALE_BIAS: logits are scaled first and bias is handed to Softmax.
    # Any other value: bias is handed to Softmax unchanged and the scale is fused into Softmax.
    attn_bias_type: Optional[AttnBiasType] = None
    # Not read by __call__.
    dtype: DType = jnp.float32
    # If True, query and key are cast to float32 before computing the logits.
    float32_logits: bool = False
    # Logit scale; None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # If True, inputs are (seqlen, batch, ...) rather than (batch, seqlen, ...).
    transpose_batch_sequence: bool = False
    # Sliding-window size forwarded to make_swa_mask; None disables the window.
    window_size: Optional[Tuple[int, int]] = None
    # Softmax variant; LEARNABLE_SOFTMAX creates a per-head ``softmax_offset`` parameter.
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Compute attention over ``query`` / ``key`` / ``value``.

        Parameters
        ----------
        query: jax.numpy.ndarray
            Rank-4 tensor ``[b, s_q, h_q, d]``, or ``[s_q, b, h_q, d]`` when
            ``transpose_batch_sequence`` is True.
        key: jax.numpy.ndarray
            ``[b, s_kv, h_kv, d]`` in the same layout as ``query``. It shares ``query``'s
            head_dim, and ``h_q`` must be a multiple of ``h_kv``.
        value: jax.numpy.ndarray
            Same batch, sequence length and head count as ``key``.
        mask: jax.numpy.ndarray, optional
            Mask where 1 means "masked out" and 0 means "keep". Its first axis is the batch
            and its last two axes are ``(s_q, s_kv)``. It is ignored for ``NO_MASK``, and for
            ``CAUSAL_MASK`` when ``window_size`` is None.
        bias: jax.numpy.ndarray, optional
            Attention bias; see ``attn_bias_type`` for where it is applied.
        dropout_rng: PRNGKey, optional
            Key for the attention-dropout mask; only used when dropout is active.
        deterministic: bool, default = False
            If True, attention dropout is disabled.

        Returns
        -------
        jax.numpy.ndarray
            Attention output in the same batch/sequence layout as ``query``.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        input_dtype = query.dtype

        # Infer number of attention heads from query shape
        # query shape: [..., h, d] where h is num_attention_heads
        num_attention_heads = query.shape[-2]

        # Initialize softmax_offset for learnable softmax
        # Note: OFF_BY_ONE_SOFTMAX is handled internally by the Softmax module
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # From here on only the resolved local ``scale_factor`` is used.
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (h_kv, group_size): each KV head serves one group.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Merge (h, g) back into a single head axis for the softmax; the grouped shape
            # is restored afterwards for the contraction with value.
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # (b, h, q, k): Last two axes are always replicated
        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                bias = None

        def apply_swa_mask(original_mask: Array) -> Array:
            """Apply the sliding window mask to a given mask"""
            batch = original_mask.shape[0]
            max_seqlen_q = original_mask.shape[-2]
            max_seqlen_kv = original_mask.shape[-1]
            # TODO(rewang): Support THD format pos
            pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
            pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
            # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
            inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
            swa_mask = 1 - inv_swa_mask
            # Positions kept by the original mask take the window's decision;
            # positions it already masks stay masked.
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask

        def convert_to_softmax_fusion_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxFusionType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None
            if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
                mask = None
            if mask is not None:
                mask = apply_swa_mask(mask)
            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
                return SoftmaxFusionType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
                return SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
                return SoftmaxFusionType.SCALED, mask
            # Only reached with mask None and a mask type other than no_mask/causal.
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_fusion_type, mask = convert_to_softmax_fusion_type(self.attn_mask_type, mask)

        attn_weights = Softmax(
            softmax_fusion_type=softmax_fusion_type,
            softmax_type=self.softmax_type,
            scale_factor=fused_scale_factor,
        )(attn_weights, mask, bias, softmax_offset=softmax_offset).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Inverted dropout: rescale kept weights by 1 / keep_prob.
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        assert (
            attn_weights.dtype == input_dtype
        ), f"output={attn_weights.dtype}, input={input_dtype}"

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
291
292


293
294
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """Dot-product attention backed by the ``fused_attn`` primitive.

    The inputs are packed according to ``qkv_layout`` (qkv-packed, kv-packed or separate).
    When ``transpose_batch_sequence`` is set, they are first moved to batch-major order.
    ``fused_attn`` is then called once, and the output is returned in the caller's layout.
    """

    # Dropout probability forwarded to fused_attn.
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    # Not read by __call__.
    dtype: DType = jnp.float32
    # Decides how query/key/value are interpreted and packed (see __call__).
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    # Logit scale; None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # If True, inputs are (seqlen, batch, ...) rather than (batch, seqlen, ...).
    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"
    # LEARNABLE_SOFTMAX creates a per-head ``softmax_offset`` parameter.
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """Run fused attention.

        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Interpreted according to ``qkv_layout``:

            * qkv-packed: ``query`` is the packed tensor ``[..., 3, h, d]``; ``key`` and
              ``value`` are ignored.
            * kv-packed: ``query`` is ``[..., h, d]`` and ``key`` is the packed tensor
              ``[..., 2, h, d]``; ``value`` is ignored.
            * separate: ``query``, ``key`` and ``value`` are each ``[..., h, d]``.

            The leading two axes are ``(batch, seqlen)``, or ``(seqlen, batch)`` when
            ``transpose_batch_sequence`` is True.
        sequence_descriptor: SequenceDescriptor, optional
            Forwarded unchanged to ``fused_attn``.
        bias: jax.numpy.ndarray, optional
            Forwarded unchanged to ``fused_attn``.
        dropout_rng: PRNGKey, optional
            Split into ``num_of_devices()`` seeds for ``fused_attn``; no seed when None.
        deterministic: bool, default = False
            Passed to ``fused_attn`` as ``is_training=not deterministic``.

        Returns
        -------
        jax.numpy.ndarray
            The attention output, in the caller's batch/sequence layout.

        Raises
        ------
        ValueError
            If ``qkv_layout`` is neither qkv-packed, kv-packed nor separate.
        """
        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # From here on only the resolved local ``scale_factor`` is used.
        del self.scale_factor

        num_attention_heads = query.shape[-2]
        softmax_offset = None
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # For learnable softmax, create a learnable parameter with proper sharding and shape (1, h, 1, 1)
            softmax_offset = self.param(
                "softmax_offset",
                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
                (1, num_attention_heads, 1, 1),
                jnp.float32,
            )

        # Pack the inputs for fused_attn, moving them to batch-major order if needed.
        if self.qkv_layout.is_qkvpacked():
            # query is the qkv-packed tensor [..., 3, h, d]; key and value are ignored.
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv_args = (qkv_packed,)
        elif self.qkv_layout.is_kvpacked():
            # query is [..., h, d]; key is the kv-packed tensor [..., 2, h, d]; value is ignored.
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv_args = (query, kv_packed)
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv_args = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            qkv_args,
            bias,
            sequence_descriptor,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            softmax_type=self.softmax_type,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
            window_size=self.window_size,
            max_segments_per_seq=self.max_segments_per_seq,
            context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
            context_parallel_axis=self.context_parallel_axis,
            context_parallel_strategy=self.context_parallel_strategy,
            context_checkpoint_name=self.context_checkpoint_name,
            softmax_offset=softmax_offset,
        )

        # Restore the caller's (seqlen, batch, ...) layout.
        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        assert x.dtype == query.dtype
        return x


437
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardwares.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

453
454
455
456
        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
          attention kernel is not available on the system, a warning will be issued, and the module
          will automatically fall back to the unfused backend.
457

458
459
460
461
462
463
464
465
    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the querys.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

502
503
        .. note:: THD format only supports 'padding' or 'causal_padding' mask type.

504
505
506
507
508
509
510
511
       attn_mask_type       mask/sequence_descriptor       SWA          softmax type
       --------------------------------------------------------------------------------------------
       no_mask              None                           None         SCALED
       causal               None                           None         SCALED_UPPER_TRIANG_MASKED
       causal               None                           Yes          SCALED_MASKED
       padding              Required                       Yes/No       SCALED_MASKED
       padding_causal       Required                       Yes/No       SCALED_MASKED

512
    attn_bias_type: Optional[str], default = None
513
        Type of the attention bias passed in the attention.
514
515
516
517
518
519
520
521
522
523
524
525
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
526
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.
527
528
529
530
531
532

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treaded as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are seperated with shape = [b, s, h, d].
533
534
        * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple
          sequences to be packed in a batch, also known as sequence packing.
535
536
537
538
539
540
541
542
543
544
545
546

        Explanation of denotations:

        * b: batch size
        * s: seqeuence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
547
548
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
549
        Indicate whether the input tensors were switched axis of batch
550
        and sequence length dimension. If set to True, the input tensors
551
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
552
553
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
554
555
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
556
557
558
    context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
    context_parallel_axis (str): The name of the context parallel axis.
559
    context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
560
    context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
561
562
563
564
565
566
567
568
569
570
571
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        softmax type as described in this paper:
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
        For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
        where alpha is a learnable parameter in shape [h].
        'off-by-one' and 'learnable' softmax types are also called sink attention
        ('zero sink' and 'learnable sink').
572
573
574

    Optimization parameters
    -----------------------
575
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
576
        The data type used to allocate the initial parameters.
577
    """
578

579
580
581
    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
582
583
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
584
585
    attn_bias_type: AttnBiasType = None
    dtype: DType = jnp.float32
586
    dropout_rng_name: str = "dropout"
587
    float32_logits: bool = False
588
    qkv_layout: str = "bshd_bshd_bshd"
589
    scale_factor: Optional[float] = None
590
    transpose_batch_sequence: bool | None = None
591
    window_size: Optional[Tuple[int, int]] = None
592
    max_segments_per_seq: Optional[int] = 1
593
594
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
595
    context_parallel_strategy: str = "DEFAULT"
596
    context_checkpoint_name: str = "context"
597
    softmax_type: str = "vanilla"
598

599
600
601
602
603
604
605
606
607
608
609
    def __post_init__(self):
        """Resolve an unset ``transpose_batch_sequence`` to False, warning about the new default."""
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        # None means the caller never chose a layout explicitly, so it falls back to the
        # default. That resolution must happen before the base-class post-init runs.
        relying_on_default = self.transpose_batch_sequence is None
        if relying_on_default:
            warnings.warn(
                "transpose_batch_sequence defaults to False in DotProductAttention starting"
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False
        super().__post_init__()

610
    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
        mask: Optional[Union[SequenceDescriptor, Array]] = None,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
            Its sequence length is not read for packed-QKV layouts.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        sequence_descriptor: SequenceDescriptor or jax.numpy.ndarray, default = None
            Mask / sequence information forwarded to the attention backend. When a boolean
            array is given, :attr:`True` means to mask out the corresponding values.
            The unfused backend (used when no fused kernel is available or when
            ``NVTE_FUSED_ATTN=0``) only accepts an array or None.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input. It must be None exactly when
            the resolved attention bias type is 'no_bias'.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        mask: jax.numpy.ndarray, default = None
            Deprecated alias of :attr:`sequence_descriptor`. Passing both raises ``ValueError``.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors, with the same dtype as :attr:`query`.
        """
        input_dtype = query.dtype

        # `mask` is the deprecated spelling of `sequence_descriptor`.
        if mask is not None:
            if sequence_descriptor is not None:
                raise ValueError(
                    "sequence_descriptor and mask cannot be provided at the same time."
                )
            warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
            sequence_descriptor = mask
            del mask

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
        # Drop the string-typed fields so only the enum versions above are used below.
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None, f"bias must be None when attn_bias_type is {attn_bias_type}"
        else:
            assert bias is not None, f"bias is required when attn_bias_type is {attn_bias_type}"

        # Use fused attn (if kernel check below passes) by default
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout.is_qkvpacked():
            # Packed QKV: K and V are stored inside `query`, so they share its sequence
            # length and `key` is not consulted.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        if qkv_layout.is_separate():
            head_dim_qk = query.shape[-1]
            head_dim_v = value.shape[-1]
        else:
            head_dim_qk = self.head_dim
            head_dim_v = self.head_dim

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            self.attention_dropout,
            self.num_attention_heads,
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            head_dim_qk,
            head_dim_v,
            self.window_size,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            # Report every input of the kernel-availability check so the fallback is diagnosable.
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{softmax_type=}\n{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
                f"{self.window_size=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(head_dim_qk)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        # case-insensitive mapping for context parallel strategy
        cp_strategy_map = {
            "DEFAULT": CPStrategy.DEFAULT,
            "ALL_GATHER": CPStrategy.ALL_GATHER,
            "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
            "RING": CPStrategy.RING,
        }

        strategy_key = self.context_parallel_strategy.upper()
        if strategy_key in cp_strategy_map:
            context_parallel_strategy = cp_strategy_map[strategy_key]
        else:
            valid_strategies = list(cp_strategy_map.keys())
            raise ValueError(
                f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
                f"Valid options are: {valid_strategies} (case insensitive)"
            )

        if not use_fused_attn:
            # unfused attention only supports splitted query, key, value
            if qkv_layout.is_qkvpacked():
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout.is_kvpacked():
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout.is_separate()

            assert sequence_descriptor is None or isinstance(
                sequence_descriptor, (jnp.ndarray, np.ndarray)
            )

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                window_size=self.window_size,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
                window_size=self.window_size,
                max_segments_per_seq=self.max_segments_per_seq,
                context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
                softmax_type=softmax_type,
            )(
                query,
                key,
                value,
                sequence_descriptor,
                bias,
                dropout_rng=dropout_rng,
                deterministic=deterministic,
            )
        assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
        return x
811
812


813
814
815
816
817
818
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding
    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input to rotate. The last (Hidden) dimension must be even because coordinates
        are rotated in pairs.
    windows: Tuple[int, int]
        (min, max) time-scales. Pair ``k`` of the ``Hidden // 2`` pairs uses the time-scale
        ``min * (max / min) ** (2k / Hidden)``.
    transpose_batch_sequence: bool
        Selects which of the two leading axes is the sequence axis (see above).
    group_method: str, default = 'consecutive'
        How coordinates are paired. 'consecutive' rotates ``(2k, 2k + 1)``, and
        'alternate' rotates ``(k, k + Hidden // 2)``. Case, '-' and '_' are ignored.

    Returns
    -------
    jax.numpy.ndarray with the same shape and dtype as ``x``. At sequence position ``p``,
    each pair is rotated by the angle ``p / time_scale`` of that pair.

    Raises
    ------
    AssertionError if ``group_method`` is unknown; ValueError if Hidden is odd.
    """

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "")
        assert canonicalized_gm in ["consecutive", "alternate"], (
            "Invalid rotary positional embedding group method. "
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."
        )
        return canonicalized_gm

    group_method = canonicalize_group_method(group_method)

    hidden_dim = x.shape[-1]
    if hidden_dim % 2 != 0:
        # Both pairings need an even Hidden; otherwise shapes silently mismatch
        # (e.g. Hidden == 1 would broadcast to an empty result).
        raise ValueError(
            f"rotary_pos_emb requires an even hidden dimension, but got {hidden_dim}."
        )
    half_hidden_dim = hidden_dim // 2

    min_window = windows[0]
    max_window = windows[1]

    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    # Make the per-pair time-scales broadcastable against every leading axis of x.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Positions 0..S-1 on the sequence axis, singleton on every other axis.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        return jnp.sin(sinusoidal_positions), jnp.cos(sinusoidal_positions)

    def alternate_impl():
        # Pair coordinate k with k + Hidden // 2.
        sin, cos = generate_sin_cos(time_scales)
        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)
        return jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)

    def consecutive_impl():
        # Pair coordinate 2k with 2k + 1; each pair shares one time-scale.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))
        parity = jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2)
        # Partner of coordinate i is i + 1 for even i and i - 1 for odd i.
        x_shifted = jnp.where(parity == 1, jnp.roll(x, 1, axis=-1), jnp.roll(x, -1, axis=-1))
        # Even coordinates subtract their partner's term, odd coordinates add it.
        sign = jnp.sign(parity - 0.5)
        output = x * cos + x_shifted * sin * sign
        return output.astype(x.dtype)

    if group_method == "alternate":
        return alternate_impl()
    return consecutive_impl()
890
891


892
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope

    Boolean flags selecting which projections receive low-rank adaptation:
    ``qkv_proj`` (query/key/value projection), ``output_proj`` (attention output
    projection) and ``mlp``.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Compare by the three flags. Objects that do not expose them defer to Python's
        # default handling (NotImplemented) instead of raising AttributeError.
        try:
            other_flags = (other.qkv_proj, other.output_proj, other.mlp)
        except AttributeError:
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == other_flags

    # Defining __eq__ already disables hashing. This makes it explicit, since the flags
    # may be reassigned after construction.
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj}, "
            f"output_proj={self.output_proj}, mlp={self.mlp})"
        )
906
907
908
909


def _canonicalize_lora_scope(scope):
    """Translate a ``low_rank_adaptation_scope`` string into a :class:`LoRAScope`.

    Parameters
    ----------
    scope: str or None
        Case-insensitive scope name; None is treated as 'none'. Accepted values are
        'none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj' and 'exclude_mlp'.

    Returns
    -------
    LoRAScope whose ``qkv_proj`` / ``output_proj`` / ``mlp`` flags mark where LoRA applies.

    Raises
    ------
    AssertionError if ``scope`` is not one of the accepted values.
    """
    # scope name -> (qkv_proj, output_proj, mlp)
    scope_to_flags = {
        "none": (False, False, False),
        "all": (True, True, True),
        "qkv_proj": (True, False, False),
        "output_proj": (False, True, False),
        "mlp": (False, False, True),
        "exclude_qkv_proj": (False, True, True),
        "exclude_output_proj": (True, False, True),
        "exclude_mlp": (True, True, False),
    }

    normalized = "none" if scope is None else scope.lower()
    assert normalized in scope_to_flags, (
        f"Invalid low rank adaptation scope: {scope!r}. "
        f"Expected one of {list(scope_to_flags)} (case insensitive)."
    )

    qkv_proj, output_proj, mlp = scope_to_flags[normalized]
    return LoRAScope(qkv_proj=qkv_proj, output_proj=output_proj, mlp=mlp)


948
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
949
950
951
952
953
954
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
955
    head_dim: int
956
        The hidden dimension of each attention head.
957
958
959
960
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
967
968
969
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

989
990
991
992
993
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
994
    dropout_rng_name: str, default = 'dropout'
995
996
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
997
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
998
        Indicate the type of layer normalization.
999
    layernorm_epsilon: float, default = 1e-6
1000
        A value added to the denominator of layer normalization for numerical stability.
1001
    zero_centered_gamma: bool, default = False
1002
1003
1004
1005
1006
1007
1008
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1009
1010
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
1011
        Used for initializing the QKV and output projection weights.
1012
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1013
    use_bias: bool, default = False
1014
        Indicate whether or not to enable bias shifting for QKV and output projections.
1015
        If set to False, the layer will not learn additive biases.
1016
    bias_init: Initializer, default = flax.linen.initializers.zeros
1017
1018
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1019
1020
1021
1022
1023
1024
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
1025
1026
1027
1028
1029
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1030
1031
1032
1033
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`2k` with
        :math:`2k + 1`.
1034
1035
1036
1037
1038
1039
1040
1041
1042
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when low rank adaptation
        is enabled through :attr:`low_rank_adaptation_scope`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
1043
1044
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1045
1046
1047
1048
1049
1050
1051
1052
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`.
1053
1054
1055

    Optimization parameters
    -----------------------
1056
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1057
        The data type used to allocate the initial parameters.
1058
    fuse_qkv_params: bool, default = True
1059
        If set to True, this module exposes a single fused
1060
1061
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1062
1063
    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
1064
        Indicate whether the input tensors were switched axis of batch
1065
1066
1067
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q*K}{\sqrt{head\_dim}}`,
        else :math:`Q*K`
1071
1072
1073
1074
1075
1076
1077
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
1078
1079
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        softmax type as described in this paper:
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
        For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
        where alpha is a learnable parameter in shape [h].
        'off-by-one' and 'learnable' softmax types are also called sink attention
        ('zero sink' and 'learnable sink').
1091
1092
1093
    """

    head_dim: int
1094
1095
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
1096
1097
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
1098
    input_layernorm: bool = True
1099
1100
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
1101
    return_layernorm_output: bool = False
1102
    zero_centered_gamma: bool = False
1103
1104
1105
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
1106
    attn_mask_type: str = "causal"
1107
    attn_bias_type: Optional[str] = None
1108
1109
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1110
1111
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1112
1113
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1114
    dtype: DType = jnp.float32
1115
    fuse_qkv_params: bool = True
1116
    transpose_batch_sequence: bool | None = None
1117
    enable_sequence_parallel: bool = False
1118
1119
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1120
    float32_logits: bool = False
1121
    window_size: Optional[Tuple[int, int]] = None
1122
    softmax_type: str = "vanilla"
1123
1124
1125
1126
1127
1128
1129

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None
1130
1131

    def __post_init__(self):
        """Normalize constructor arguments before the parent ``__post_init__`` runs.

        - ``transpose_batch_sequence=None`` means the caller relied on the default: warn
          about the changed default and use ``False``.
        - Each deprecated alias (``num_heads``, ``dropout_rate``,
          ``apply_residual_connection_post_layernorm``, ``fuse_qkv``) that is set is
          forwarded to its replacement field, with a ``DeprecationWarning``.
        - ``output_layernorm`` must be left as None.
        - ``kernel_init`` and ``num_gqa_groups`` get their defaults when None.

        Raises
        ------
        AssertionError if ``output_layernorm`` is given.
        """
        # Deal with changed defaults in API
        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
        # None implies that the user is relying on defaults, hence warn the user and set the new defaults
        if self.transpose_batch_sequence is None:
            warnings.warn(
                "transpose_batch_sequence defaults to False in MultiHeadAttention starting"
                " TransformerEngine v2.10"
            )
            self.transpose_batch_sequence = False

        # Deal with the deprecated parameters. Every alias that is still accepted is
        # forwarded to its replacement, so a value passed under the old name is honored
        # rather than silently ignored.
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )

        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )

        if self.apply_residual_connection_post_layernorm is not None:
            self.return_layernorm_output = self.apply_residual_connection_post_layernorm
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )

        if self.fuse_qkv is not None:
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )

        # output_layernorm has no forwarding target here, so it is rejected outright.
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )

        # Without an explicit GQA group count, use one KV head per query head (plain MHA).
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads

        super().__post_init__()

    @nn.compact
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
1192
1193
1194
1195
1196
1197
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
1198
        inputs_q: jax.numpy.ndarray
1199
            Input tensor for query projection.
1200
        inputs_kv: jax.numpy.ndarray
1201
            Input tensor for key/value projection.
1202
1203
1204
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
1205
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
1206
1207
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
1208
        *
1209
        decode: bool, default = False
1210
            Indicate whether to prepare and use an autoregressive cache.
1211
        deterministic: bool, default = False
1212
1213
1214
1215
            Disable dropout layers if set to True.

        Returns
        -------
1216
        outputs: jax.numpy.ndarray
1217
1218
            Output tensors.
        """
1219

1220
1221
1222
1223
1224
        assert (
            inputs_q.dtype == inputs_kv.dtype
        ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
        input_dtype = inputs_q.dtype

1225
        def query_init(*args):
1226
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

1269
1270
1271
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
1272

1273
1274
1275
1276
        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
1277
1278
1279
1280
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

1281
1282
        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

1283
        if self.fuse_qkv_params:
zlsh80826's avatar
zlsh80826 committed
1284
            if is_qkvpack:
1285
                qkv_proj, ln_out = LayerNormDenseGeneral(
1286
                    enable_layernorm=self.input_layernorm,
1287
                    layernorm_type=self.layernorm_type,
1288
                    zero_centered_gamma=self.zero_centered_gamma,
1289
1290
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1291
1292
                    features=(3, self.num_attention_heads * self.head_dim),
                    return_layernorm_output=self.return_layernorm_output,
1293
1294
1295
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
1296
1297
1298
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1299
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
1300
1301
1302
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1303
1304
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1305
                    transpose_batch_sequence=self.transpose_batch_sequence,
1306
1307
1308
1309
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
1310
                qkv_layout = QKVLayout.BS3HD
1311
1312
            else:
                query, ln_out = LayerNormDenseGeneral(
1313
                    enable_layernorm=self.input_layernorm,
1314
                    layernorm_type=self.layernorm_type,
1315
                    zero_centered_gamma=self.zero_centered_gamma,
1316
1317
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
1318
1319
                    features=self.num_attention_heads * self.head_dim,
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
1320
1321
1322
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1323
1324
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
1325
                    bias_axes=(W_TP_AXES,),
1326
1327
1328
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1329
1330
                    dtype=self.dtype,
                    kernel_init=query_init,
1331
1332
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
1333
                    transpose_batch_sequence=self.transpose_batch_sequence,
1334
1335
                    name="query",
                )(inputs_q)
zlsh80826's avatar
zlsh80826 committed
1336
1337
1338
1339
1340

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1352
                    transpose_batch_sequence=self.transpose_batch_sequence,
1353
1354
1355
1356
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
1357
                qkv_layout = QKVLayout.BSHD_BS2HD
1358
1359
1360
1361
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
1362
                features=self.num_gqa_groups * self.head_dim,
1363
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1364
1365
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1366
                bias_axes=(W_TP_AXES,),
1367
1368
1369
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1370
1371
                dtype=self.dtype,
            )
1372
            query, ln_out = LayerNormDenseGeneral(
1373
                enable_layernorm=self.input_layernorm,
1374
                layernorm_type=self.layernorm_type,
1375
                zero_centered_gamma=self.zero_centered_gamma,
1376
1377
                epsilon=self.layernorm_epsilon,
                axis=-1,
1378
                features=self.num_attention_heads * self.head_dim,
1379
                return_layernorm_output=True,
1380
1381
1382
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
1383
1384
                use_bias=self.use_bias,
                bias_init=self.bias_init,
1385
                bias_axes=(W_TP_AXES,),
1386
1387
1388
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
1389
1390
                dtype=self.dtype,
                kernel_init=query_init,
1391
1392
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
1393
                transpose_batch_sequence=self.transpose_batch_sequence,
1394
1395
                name="query",
            )(inputs_q)
1396

1397
            if is_self_attn:
1398
1399
1400
                assert ln_out is not None
                inputs_kv = ln_out

1401
            query = query.astype(input_dtype)
1402
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
1403
            key = key.astype(input_dtype)
1404
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
1405
            value = value.astype(input_dtype)
1406
1407
1408
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
1409
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1410

1411
        if self.enable_rotary_pos_emb:
1412
1413
1414
1415
1416
1417
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1418

1419
            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
1420
1421
1422
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
1435
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD
1436

1437
1438
        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
1439
1440
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
1441
1442

        if decode:
1443
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
1444
1445
1446
1447
1448
1449
1450
1451
1452
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
1453
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1454
                if self.transpose_batch_sequence:
1455
1456
                    length, batch, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1457
1458
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
1459
1460
                    batch, length, num_attention_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_attention_heads, head_dim)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1461
                    one_hot_indices_shape = (1, length, 1, 1)
1462
1463
1464
1465

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
1466
1467
1468
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
1469

1470
                cur_index = cache_index.value.astype(jnp.int32)
1471
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1472
1473
1474
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
1475
1476
1477
1478
1479
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
1480
1481
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )
1482
1483

                if bias is not None:
1484
1485
1486
1487
1488
1489
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
1490

1491
1492
1493
1494
1495
        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
1496
1497
1498
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
1510
        else:
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

1521
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
1535
            window_size=self.window_size,
1536
            softmax_type=self.softmax_type,
1537
        )(*dpa_args, mask, bias, deterministic=deterministic)
1538
1539
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

1540
        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
1541
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)
1542

1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
        out = DenseGeneral(
            features=inputs_q.shape[-1],
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
1558

1559
1560
1561
        assert (
            inputs_q.dtype == out.dtype
        ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
1562
        return out, ln_out
1563
1564


1565
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Bucketed, learnable T5-style relative position biases added to attention logits.

    Every (query, key) pair is mapped to a bucket id based on the key-minus-query
    offset. Small offsets get one bucket each, larger offsets share log-spaced buckets,
    and offsets of `max_distance` or more land in the last bucket. Each bucket owns
    one learnable scalar per attention head.

    Parameters
    ----------
    num_buckets: int
        Total number of distance buckets in the embedding table.
    max_distance: int
        Offset at and beyond which all distances fall into the final bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer; one bias per head.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Initializer for the `(num_attention_heads, num_buckets)` bias table.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        Logical axis names used to shard the bias table over a device mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Build the relative position bias for a `q_seqlen` x `k_seqlen` attention map.

        Parameters
        ----------
        q_seqlen: int
            Number of query positions.
        k_seqlen: int
            Number of key positions.
        bidirectional: bool, default = True
            When True, the bucket table is split in half. Keys at later positions than
            the query use the upper half. When False, every key at or after the query
            maps to the zero-distance bucket.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Bucket ids are computed with NumPy, so q_seqlen / k_seqlen must be concrete
        # Python ints here (np.arange cannot consume traced values).
        query_pos = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        key_pos = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        # Positive when the key precedes the query.
        distance = -(key_pos - query_pos)

        buckets_per_side = self.num_buckets
        offset = 0
        if bidirectional:
            buckets_per_side //= 2
            # Keys at later positions than the query (negative distance) use the upper half.
            offset = offset + (distance < 0).astype(np.int32) * buckets_per_side
            distance = np.abs(distance)
        else:
            distance = np.maximum(distance, 0)

        # Distances below `exact_limit` get one bucket each. Larger ones are log-spaced
        # between `exact_limit` and `max_distance`, then clamped to the last bucket.
        exact_limit = buckets_per_side // 2
        log_scaled = np.log(distance.astype(np.float32) / exact_limit + np.finfo(np.float32).eps)
        log_scaled = log_scaled / np.log(self.max_distance / exact_limit)
        log_scaled = log_scaled * (buckets_per_side - exact_limit)
        far_bucket = exact_limit + log_scaled.astype(np.int32)
        far_bucket = np.minimum(far_bucket, buckets_per_side - 1)
        bucket_ids = offset + np.where(distance < exact_limit, distance, far_bucket)

        # Learnable table: one scalar per (head, bucket).
        table = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
            (self.num_attention_heads, self.num_buckets),
            self.dtype,
        )
        table = jnp.asarray(table, self.dtype)

        # Gather via a one-hot matmul: (buckets, q, k) one-hot of bucket_ids.
        bucket_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        one_hot = jnp.array(bucket_ids[jnp.newaxis, ...] == bucket_iota, dtype=self.dtype)

        # (heads, buckets) contracted with (buckets, q, k) -> (heads, q, k).
        bias = lax.dot_general(table, one_hot, (((1,), (0,)), ((), ())))
        return bias[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    Chooses which variant of TransformerLayer is built.

    Members
    -------
    ENCODER:
        Build the layer as an encoder block (value ``"encoder"``).
    DECODER:
        Build the layer as a decoder block (value ``"decoder"``). TransformerLayer then
        adds a cross-attention block after self-attention.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


1677
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
1678
1679
1680
1681
1682
1683
1684
1685
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
1686
        The hidden size of each input sample.
1687
    mlp_hidden_size: int, default = 2048
1688
        Intermediate size to which input samples are projected.
1689
    num_attention_heads: int, default = 8
1690
        Number of attention heads in the transformer layer.
1691
    num_gqa_groups: int, default = `None`
zlsh80826's avatar
zlsh80826 committed
1692
1693
1694
1695
1696
1697
1698
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
1699
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
1700
        Indicate the type of layer normalization.
1701
    layernorm_epsilon: float, default = 1e-6
1702
        A value added to the denominator of layer normalization for numerical stability.
1703
    zero_centered_gamma: bool, default = False
1704
1705
1706
1707
1708
1709
1710
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
1711
    hidden_dropout: float, default = 0.1
1712
        Dropout probability for the dropout op after FC2 layer.
1713
    hidden_dropout_dims: Sequence[int], default = ()
1714
        Dimensions that will share the same dropout mask for the hidden states after the FC2 layer.
1715
    attention_dropout: float, default = 0.1
1716
        Dropout probability for the dropout op during multi-head attention.
1717
    intermediate_dropout: float, default = 0.0
1718
1719
1720
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
1721
    dropout_rng_name: str, default = 'dropout'
1722
1723
        The key in the RNGs passed to flax.linen.Module.apply that is used for
        generating Dropout masks in the Multi-Head Attention.
1724
1725
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
1726
1727
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1728
1729
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
1730
1731
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1732
    mlp_activations: Sequence[str], default = ('gelu', )
1733
        The sequence of activation functions to apply after the first linear transformation.
1734
        Each activation has its own transformation layer.
1735
1736
1737
    mlp_activation_params: dict = None
        This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment
        ClampedSwiglu is the only activation that requires parameters.
1738
    use_bias: bool, default = False
1739
1740
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
1741
    bias_init: Initializer, default = flax.linen.initializers.zeros
1742
1743
1744
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
1745
    apply_residual_connection_post_layernorm: bool, default = False
1746
        If set to True, residual connections are taken from the output
1747
1748
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
1749
        If set to True, layer normalization is applied on the output side,
1750
1751
1752
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
1753
1754
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
1755
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
1756
        If set to TransformerLayerType.DECODER, an additional cross-attention block
1757
1758
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
1759
    self_attn_mask_type: str, default = 'causal'
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

1779
1780
1781
1782
1783
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When left as the default (None), the type is decided automatically from the MHA's bias
        argument: `post_scale_bias` if a bias is given, otherwise `no_bias`.
1784
    enable_relative_embedding: bool, default = True
1785
        Whether to enable relative embedding as shifting of attention logits.
1786
    relative_embedding: flax.linen.Module, default = None
1787
        The module for relative embedding execution, only used when
1788
1789
1790
1791
1792
1793
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
1794
1795
1796
1797
1798
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
1799
    rotary_pos_emb_group_method: str, default = 'consecutive'
1800
1801
1802
1803
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
1804
1805
1806
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
1807
        'exclude_output_proj', 'exclude_mlp']
1808
1809
1810
1811
1812
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
1813
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
1814
1815
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
1816
1817
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. Default value is no sliding window.
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
        Softmax type as described in this paper:
        `Efficient Streaming Language Models with Attention Sinks
        <https://arxiv.org/pdf/2309.17453v3>`_.
        For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
        where alpha is a learnable parameter in shape [h].
        'off-by-one' and 'learnable' softmax types are also called sink attention
        ('zero sink' and 'learnable sink').
        Only supported for fused attention backend.
1830
1831
1832

    Optimization parameters
    -----------------------
1833
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
1834
        The data type used to allocate the initial parameters.
1835
    drop_path: float, default = 0.0
1836
        When > 0.0, applies stochastic depth per sample in the main
1837
1838
        path of the residual block.
    fuse_qkv_params: bool, default = True
1839
        If set to True, `TransformerLayer` module exposes a single fused
1840
1841
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
1842
    transpose_batch_sequence: bool, default = False
1843
        Indicate whether the input tensors were switched axis of batch
1844
1845
1846
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
1847
        Indicate whether to scale attention logits.
1848
1849
1850
        If set to True, :math:`\frac{Q*K}{\sqrt{head\_dim}}`,
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
1851
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
1852
1853
1854
1855
1856
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
1857
    num_gqa_groups: Optional[int] = None
1858
    layernorm_type: str = "layernorm"
1859
    layernorm_epsilon: float = 1e-6
1860
    zero_centered_gamma: bool = False
1861
1862
1863
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
1864
    intermediate_dropout: float = 0.0
1865
    intermediate_dropout_dims: Sequence[int] = ()
1866
    dropout_rng_name: str = "dropout"
1867
1868
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
1869
    mlp_activations: Sequence[str] = ("gelu",)
1870
    mlp_activation_params: dict = None
1871
1872
1873
1874
1875
1876
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
1877
    self_attn_mask_type: str = "causal"
1878
    self_attn_bias_type: Optional[str] = None
1879
1880
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
1881
1882
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
1883
1884
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
1885
1886
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1887
1888
1889
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
1890
    transpose_batch_sequence: bool = False
1891
    enable_sequence_parallel: bool = False
1892
1893
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
1894
    window_size: Optional[Tuple[int, int]] = None
1895
    softmax_type: str = "vanilla"
1896
1897
1898

    def __post_init__(self):
        # Resolve field defaults that depend on other fields (dtype, head count)
        # before delegating to the parent dataclass/Flax post-init.
        kernel_init_defaults = (
            ("mha_kernel_init", "normal"),
            ("mlp_kernel_init", "truncated_normal"),
        )
        for field_name, distribution in kernel_init_defaults:
            if getattr(self, field_name) is None:
                # Both kernels fall back to fan-in variance scaling (scale 1.0) in
                # self.dtype; only the sampling distribution differs.
                default_init = nn.initializers.variance_scaling(
                    1.0, "fan_in", distribution, dtype=self.dtype
                )
                setattr(self, field_name, default_init)

        # Default the GQA group count to the number of attention heads.
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads

        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,
    ):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor. Must be rank 3: [batch, seqlen, hidden], or
            [seqlen, batch, hidden] when :attr:`transpose_batch_sequence=True`.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
            :attr:`True` means mask out the corresponding values.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA). Must be False for encoder layers.
        max_decode_length: int, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER`,
            :attr:`enable_relative_embedding=True` and :attr:`decode=True`.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors, with the same dtype as :attr:`inputs`.
        """
        input_dtype = inputs.dtype

        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )

        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."

        # Validate the rank before any shape-dependent work (the relative bias below
        # is sized from inputs.shape[sequence_dim]).
        assert inputs.ndim == 3, f"inputs should be a rank-3 tensor, but got {inputs.ndim=}."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            # Logical sharding axes for the (batch, seqlen) dims in the layout chosen by
            # transpose_batch_sequence. is_shared_seq defaults to enable_sequence_parallel;
            # when truthy, the seqlen dim is mapped to SEQLEN_TP_AXES.
            axes = [None, None]

            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(
                        1.0, "fan_avg", "uniform", dtype=self.dtype
                    ),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding

            # NOTE(review): the third argument is True for encoders and False for
            # decoders; presumably it selects bidirectional buckets — confirm against
            # RelativePositionBiases.
            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                # When decoding with max_decode_length set, size the bias to that length
                # instead of the current input length.
                if decode and max_decode_length:
                    bias_len = max_decode_length
                else:
                    bias_len = inputs.shape[sequence_dim]
                attn_bias = rel_emb(bias_len, bias_len, False)

        # Keep module names identical to T5X, since names affect the RNG keys derived
        # during init and apply. This may not be needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = "attention"
        else:
            mha_name = "self_attention"

        inputs = with_sharding_constraint_by_logical_axes(
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        # Self-attention block. The layout is [batch, seqlen, hidden], or
        # [seqlen, batch, hidden] when transpose_batch_sequence is set.
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name,
            window_size=self.window_size,
            softmax_type=self.softmax_type,
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)

        def hidden_dropout(x, deterministic):
            # Dropout on block outputs; hidden_dropout_dims are broadcast dims and must be
            # valid (possibly negative) axis indices of x.
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        x = with_sharding_constraint_by_logical_axes(
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        # With post-layernorm residuals, the residual branch is the LayerNorm output
        # returned by the attention block rather than the block input.
        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = with_sharding_constraint_by_logical_axes(
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            # Encoder-decoder (cross) attention over `encoded`.
            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name="encoder_decoder_attention",
                window_size=self.window_size,
                softmax_type=self.softmax_type,
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

            y = with_sharding_constraint_by_logical_axes(
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            residual = with_sharding_constraint_by_logical_axes(
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = with_sharding_constraint_by_logical_axes(
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            activation_params=self.mlp_activation_params,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            transpose_batch_sequence=self.transpose_batch_sequence,
            name="mlp",
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = with_sharding_constraint_by_logical_axes(
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            # Use the configured RNG collection, matching the attention-branch drop path
            # and hidden_dropout (previously this fell back to Dropout's default).
            z = nn.Dropout(
                rate=self.drop_path,
                broadcast_dims=drop_path_shape,
                rng_collection=self.dropout_rng_name,
            )(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            z = with_sharding_constraint_by_logical_axes(
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                dtype=self.dtype,
                name="output_layernorm",
            )(z)

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        return z