# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
18
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
19
from flax.linen.attention import combine_masks
20
21
22
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
23
from jax.ad_checkpoint import checkpoint_name
24
25
26

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
27
28
from ..attention import AttnBiasType, AttnMaskType, QKVLayout
from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type
29
from ..attention import fused_attn
30
from ..softmax import SoftmaxType
31
32
33
34
35
36
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
37
38
39
40
41

# Type aliases used in the annotations throughout this module.
PRNGKey = Any  # annotates PRNG key arguments such as `dropout_rng`
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# A matmul precision setting: a single value or a pair of values (presumably one per
# operand, as accepted by `lax.dot_general`'s `precision` argument).
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
# Parameter initializer signature: (rng, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical axis rules: (logical_axis_name, mesh_axis_name or None) pairs.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    """Return the indices of every axis of ``shape`` except ``batch_dim``.

    These are the broadcast dims for drop_path. ``batch_dim`` may be negative (it is
    normalized like a list index); an out-of-range ``batch_dim`` raises ``IndexError``.
    """
    axes = range(len(shape))
    # Indexing the range normalizes negative indices and rejects out-of-range ones.
    excluded = axes[batch_dim]
    return [axis for axis in axes if axis != excluded]


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, as a tuple. The given rules come first, followed
        by each TE rule whose logical axis does not already appear in ``rules``.

    Raises
    ------
    AssertionError
        If a rule is not a ``(str, str or None)`` pair, or if ``rules`` maps a logical axis
        that TE also maps to anything other than exactly TE's mesh axis.
    """
    # Every mesh axis the caller maps each logical axis to (a key may appear more than once).
    rules_map = {}
    for item in rules:
        assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item
        assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (
            val is None
        ), f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for key, val in get_sharding_map_logic_axis_to_mesh_axis().items():
        if key in rules_map:
            # A user-provided rule for a TE axis must agree exactly with TE's own mapping.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, (
                "The rule diverged between TE and given rule. "
                f"Axis:{key} map to {rules_map[key]} in the given"
                f" rules, but {val} in TE's rules."
            )
        else:
            extended_rules.append((key, val))
    return tuple(extended_rules)


113
114
class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Dot-product attention implemented with native JAX operations.

    Attention logits are computed with ``jnp.einsum`` and normalized by the ``Softmax`` module,
    which receives the (possibly fused) scale factor, the mask and the bias. Dropout is then
    optionally applied, and the probabilities are contracted with ``value``. When key/value
    carry fewer heads than query, a separate grouped-query (GQA) code path is used.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    float32_logits: bool = False
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """
        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Tensors laid out as [seqlen, batch, heads, head_dim] when
            ``transpose_batch_sequence`` is True, otherwise [batch, seqlen, heads, head_dim].
            ``key``/``value`` may have fewer heads than ``query``, provided the query head
            count is a multiple of the key/value head count (GQA).
        mask: jax.numpy.ndarray, default = None
            Forwarded to ``Softmax``; ignored for the 'no_mask' and 'causal' mask types.
        bias: jax.numpy.ndarray, default = None
            Attention bias, applied according to ``attn_bias_type``.
        dropout_rng: PRNGKey, default = None
            Key for the dropout mask; required when dropout is active.
        deterministic: bool, default = False
            Disable dropout if set to True.

        Returns
        -------
        jax.numpy.ndarray
            Attention output in the same layout as ``query``.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # Drop the field so the resolved local `scale_factor` is the only one used below.
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = h_q != h_kv

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (h_kv, group_size): query head i is paired with
            # key/value head i // group_size.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
            else:
                attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)

        attn_weights = checkpoint_name(attn_weights, "logits")

        if is_gqa:
            # Merge (heads, groups) so the softmax sees a plain [b, h, q, k] tensor.
            b, h, g, seq_q, seq_k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights = attn_weights.reshape((b, h * g, seq_q, seq_k))

        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
        )

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.0
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                # The bias is now part of the logits; Softmax also adds any bias it is given
                # (the post-scale path above relies on that), so stop it being added twice.
                bias = None

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                mask = None
            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
                if mask is not None:
                    return SoftmaxType.SCALED_MASKED, mask
                return SoftmaxType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
            attn_weights, mask, bias
        ).astype(self.dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.0:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Inverted dropout: rescale the kept weights by 1 / keep_prob.
            multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
            attn_weights = attn_weights * multiplier

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
            return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value)

        if is_gqa:
            return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
        return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
237
238


239
240
class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Dot-product attention backed by the ``fused_attn`` primitive.

    The inputs are packed according to ``qkv_layout`` and dispatched in a single
    ``fused_attn`` call. When ``transpose_batch_sequence`` is True, the inputs are transposed
    from (seqlen, batch, ...) to (batch, seqlen, ...) before the call and the output is
    transposed back afterwards.
    """

    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        dropout_rng: Optional[PRNGKey] = None,
        deterministic: bool = False,
    ) -> Array:
        """
        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Interpreted according to ``qkv_layout``:

            * BS3HD: ``query`` is the packed qkv tensor [..., 3, h, d]; ``key`` and ``value``
              are ignored.
            * BSHD_BS2HD: ``query`` is [..., h, d] and ``key`` is the packed kv tensor
              [..., 2, h, d]; ``value`` is ignored.
            * BSHD_BSHD_BSHD: separate ``query``, ``key`` and ``value`` tensors [..., h, d].
        mask, bias: jax.numpy.ndarray, default = None
            Forwarded to ``fused_attn`` unchanged.
        dropout_rng: PRNGKey, default = None
            When given, split into one seed per device for the kernel's dropout.
        deterministic: bool, default = False
            Passed to ``fused_attn`` as ``is_training=not deterministic``.

        Returns
        -------
        jax.numpy.ndarray
            Attention output, transposed back to (seqlen, batch, ...) if
            ``transpose_batch_sequence`` is True.

        Raises
        ------
        ValueError
            If ``qkv_layout`` is not one of the layouts above.
        """
        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # Drop the field so the resolved local `scale_factor` is the only one used below.
        del self.scale_factor

        # TODO(rewang): integrate THD format
        if self.qkv_layout == QKVLayout.BS3HD:
            # qkvpacked format: `query` is the packed tensor with shape [..., 3, h, d];
            # `key` and `value` are ignored.
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (qkv_packed,)
        elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
            # kvpacked format: `query` has shape [..., h, d] and `key` is the packed kv
            # tensor with shape [..., 2, h, d]; `value` is ignored.
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            qkv = (query, kv_packed)
        elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            qkv = (query, key, value)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        x = fused_attn(
            qkv,
            bias,
            mask,
            seed,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            qkv_layout=self.qkv_layout,
            scaling_factor=scale_factor,
            dropout_probability=self.attention_dropout,
            is_training=not deterministic,
        )

        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        return x


341
class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    .. note::
        The DotProductAttention default setting enables non-deterministic kernels for reduced
        workspace requirements and faster computation. Users can disable the non-deterministic
        kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:

        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
        * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        deterministic: bool = False,
    ) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        # Drop the string-typed fields so only the canonical enums above are used below.
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # As documented, `None` means one key/value head per query head (plain MHA). Resolve it
        # here so the kernel-availability query always receives a concrete head count.
        num_gqa_groups = (
            self.num_attention_heads if self.num_gqa_groups is None else self.num_gqa_groups
        )

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout == QKVLayout.BS3HD:
            # key/value are packed into `query`, so they share its sequence length.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            self.dtype,
            self.dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            self.attention_dropout,
            self.num_attention_heads,
            num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            self.head_dim,
        )

        use_fused_attn = enable_fused_attn and has_fused_attn_kernel

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn(
                "Fused attention is not enabled because there is no available kernel.\n"
                "Fall back to the unfused attention.\n"
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
            )

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.0:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(self.head_dim)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if not use_fused_attn:
            # unfused attention only supports splitted query, key, value
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(
                    functools.partial(jnp.squeeze, axis=-3), [query, key, value]
                )
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            x = _UnfusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

        return x
593
594


595
596
597
598
599
600
def rotary_pos_emb(
    x: Array,
    windows: Tuple[int, int],
    transpose_batch_sequence: bool,
    group_method: str = "consecutive",
):
    """
    Rotary Positional Embedding
    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor. Pairs of elements along the last (hidden) axis are rotated, so that axis
        is expected to have even length.
    windows: Tuple[int, int]
        ``(min_window, max_window)``. The per-pair time scales grow geometrically from
        ``min_window`` toward ``max_window`` across the hidden dimension.
    transpose_batch_sequence: bool
        Whether axis 0 is the sequence axis (True) or the batch axis (False). Positions are
        ``0 .. seqlen - 1`` along the sequence axis.
    group_method: str, default = 'consecutive'
        How hidden elements are paired for rotation: 'consecutive' rotates pairs
        ``(2i, 2i + 1)``, 'alternate' rotates pairs ``(i, i + hidden / 2)``. Case, surrounding
        whitespace, '-' and '_' are ignored when matching.

    Returns
    -------
    jax.numpy.ndarray
        Tensor with the same shape and dtype as ``x``.

    Raises
    ------
    AssertionError
        If ``group_method`` is not 'consecutive' or 'alternate'.
    """
    canonicalized_gm = group_method.lower().strip().replace("-", "").replace("_", "")
    assert canonicalized_gm in ["consecutive", "alternate"], (
        "Invalid relative positional embedding group method. "
        f"Expect to be in ['alternate', 'consecutive'], but got {group_method}."
    )

    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2
    min_window = windows[0]
    max_window = windows[1]

    # One time scale per rotated pair: min_window * (max_window / min_window) ** (2i / hidden).
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window) ** fraction
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Positions broadcast along every axis except the sequence axis.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        return jnp.sin(sinusoidal_positions), jnp.cos(sinusoidal_positions)

    if canonicalized_gm == "alternate":
        # Pair element i of the first half with element i of the second half.
        sin, cos = generate_sin_cos(time_scales)
        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
        return jnp.concatenate([part_1, part_2], axis=-1)

    # "consecutive": pair elements (2i, 2i + 1); both share time scale i.
    sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

    parity = jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2)  # 0 at even, 1 at odd indices
    x_shifted_left = jnp.roll(x, -1, axis=-1)  # element i holds x[i + 1]
    x_shifted_right = jnp.roll(x, 1, axis=-1)  # element i holds x[i - 1]
    # Each element's pair partner: x[i - 1] at odd i, x[i + 1] at even i. `lax.select`
    # documents a boolean predicate, so compare the parity explicitly.
    x_shifted = jax.lax.select(
        jnp.tile(parity == 1, x.shape[:-1] + (1,)),
        x_shifted_right,
        x_shifted_left,
    )

    # -1 for even indices, +1 for odd indices.
    sign = jnp.sign(parity - 0.5)

    output = x * cos + x_shifted * sin * sign
    return output.astype(x.dtype)
672
673


674
class LoRAScope:  # pylint: disable=too-few-public-methods
    """LoRA Scope: flags selecting which projections LoRA is enabled for.

    Parameters
    ----------
    qkv_proj: bool, default = False
        Flag for the QKV projection.
    output_proj: bool, default = False
        Flag for the attention output projection.
    mlp: bool, default = False
        Flag for the MLP.
    """

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        # Compare by the three flags; defer to Python's fallback for objects without them.
        try:
            other_flags = (other.qkv_proj, other.output_proj, other.mlp)
        except AttributeError:
            return NotImplemented
        return (self.qkv_proj, self.output_proj, self.mlp) == other_flags

    # Instances are mutable, so they stay explicitly unhashable (defining __eq__ alone
    # already implies this; spelled out for clarity).
    __hash__ = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(qkv_proj={self.qkv_proj!r}, "
            f"output_proj={self.output_proj!r}, mlp={self.mlp!r})"
        )
688
689
690
691


def _canonicalize_lora_scope(scope):
    """Convert a LoRA scope name into a ``LoRAScope``.

    Parameters
    ----------
    scope: Optional[str]
        One of 'none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj' or 'exclude_mlp' (case-insensitive). ``None`` means 'none'.

    Returns
    -------
    LoRAScope
        Flags for the projections selected by ``scope``.

    Raises
    ------
    AssertionError
        If ``scope`` is not one of the names above.
    """
    # scope name -> (qkv_proj, output_proj, mlp)
    scope_to_flags = {
        "none": (False, False, False),
        "all": (True, True, True),
        "qkv_proj": (True, False, False),
        "output_proj": (False, True, False),
        "mlp": (False, False, True),
        "exclude_qkv_proj": (False, True, True),
        "exclude_output_proj": (True, False, True),
        "exclude_mlp": (True, True, False),
    }

    scope = "none" if scope is None else scope.lower()
    assert scope in scope_to_flags, (
        f"Unsupported LoRA scope {scope!r}, expected one of {list(scope_to_flags)}."
    )

    qkv_proj, output_proj, mlp = scope_to_flags[scope]
    return LoRAScope(qkv_proj=qkv_proj, output_proj=output_proj, mlp=mlp)


730
class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing the QKV and output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    use_bias: bool, default = False
        Indicate whether or not to enable bias shifting for QKV and output projections.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']. This module only has QKV and output
        projections, so the MLP part of the scope has no effect here.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`low_rank_adaptation_scope` enables LoRA for a projection of this module.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`. When set, it overrides
        `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`. When set, it overrides
        `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`. Must be left as None.
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`. When set, it overrides
        `return_layernorm_output`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    fuse_qkv_params: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q K^T}{\sqrt{head\_dim}}`,
        else :math:`Q K^T`
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`. When set, it overrides
        `fuse_qkv_params`.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    return_layernorm_output: bool = False
    zero_centered_gamma: bool = False
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        # Deal with the deprecated parameters: each one that has a replacement
        # is forwarded to it (overriding the new field) and emits a warning.
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            # Previously only warned, silently ignoring the caller's value.
            self.return_layernorm_output = self.apply_residual_connection_post_layernorm
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            # Previously only warned, silently ignoring the caller's value.
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Tuple[Array, Optional[Array]]:
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q: jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv: jax.numpy.ndarray
            Input tensor for key/value projection. Passing the very same object as
            :attr:`inputs_q` selects self-attention, in which case keys/values are
            projected from the layernormed input.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
        *
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache.
            Requires the query/key/value tensors to end up in the separate
            BSHD_BSHD_BSHD layout (e.g. :attr:`fuse_qkv_params` = False).
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        ln_out: jax.numpy.ndarray
            Layernorm output returned by the query/QKV projection; this is the
            value requested via :attr:`return_layernorm_output`.
        """

        def query_init(*args):
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            # Fused kernel laid out as (hidden, 3, heads * head_dim); Q uses query_init.
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            # Fused kernel laid out as (hidden, 2, groups * head_dim).
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

        # Identity (not equality) check: self-attention is signalled by passing
        # the same array object for both inputs.
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_attention_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa

        inputs_logical_axes_maybe_sp = (
            *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel),
            HIDDEN_AXES,
        )
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        if self.fuse_qkv_params:
            if is_qkvpack:
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_attention_heads * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.return_layernorm_output,
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
                qkv_layout = QKVLayout.BS3HD
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_attention_heads * self.head_dim,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    # Self-attention needs ln_out to feed the KV projection below.
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_TP_AXES,),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    dtype=self.dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name="query",
                )(inputs_q)

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

                kv_proj = DenseGeneral(
                    axis=-1,
                    features=(2, self.num_gqa_groups * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=kv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
                qkv_layout = QKVLayout.BSHD_BS2HD
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
                features=self.num_gqa_groups * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
            )
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=self.input_layernorm,
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_attention_heads * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
                name="query",
            )(inputs_q)

            if is_self_attn:
                assert ln_out is not None
                inputs_kv = ln_out

            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
            query = checkpoint_name(query, "query_proj")
            key = checkpoint_name(key, "key_proj")
            value = checkpoint_name(value, "value_proj")
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if self.enable_rotary_pos_emb:
            # Rotary embedding needs separate Q and K, so unpack any fused projection.
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

            # NOTE(review): with decode=True each call carries a single step; confirm
            # rotary_pos_emb accounts for the cache index when computing positions.
            query = rotary_pos_emb(
                query,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            key = rotary_pos_emb(
                key,
                self.rotary_pos_emb_windows,
                self.transpose_batch_sequence,
                self.rotary_pos_emb_group_method,
            )
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if decode:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD, (
                "decode=True requires separate query/key/value tensors "
                "(QKVLayout.BSHD_BSHD_BSHD), e.g. fuse_qkv_params=False."
            )
            is_initialized = self.has_variable("cache", "cached_key")

            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                # The cached key carries num_gqa_groups heads while the query carries
                # num_attention_heads heads, so the expected query shape must use the
                # latter; otherwise every GQA/MQA decode step is wrongly rejected.
                if self.transpose_batch_sequence:
                    length, batch, _, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, _, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )

                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                # Add this step's key/value into position `cur_index` of the
                # zero-initialized cache.
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                # Mask out cache positions after the current step (True = masked).
                mask = combine_masks(
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length))
                )

                if bias is not None:
                    dynamic_vector_slice_in_dim = vmap(
                        lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)
                    )
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )

        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
            qkv_proj = qkv_proj.reshape(
                *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim
            )
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
        else:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

        x = DotProductAttention(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
            scale_factor=scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )(*dpa_args, mask, bias, deterministic=deterministic)
        # Merge the (heads, head_dim) axes back into a single hidden axis.
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)

        out = DenseGeneral(
            features=inputs_q.shape[-1],
            transpose_batch_sequence=self.transpose_batch_sequence,
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")

        return out, ln_out
1307
1308


1309
class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Every (query, key) pair is assigned to one of :attr:`num_buckets` buckets based on the
    distance between the two positions: small distances each get their own bucket, larger
    distances share logarithmically sized buckets up to :attr:`max_distance`. A learnable
    table then maps every (head, bucket) pair to a scalar bias.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
        When ``bidirectional=True`` (see :meth:`__call__`) half of the buckets are used for
        keys after the query and the other half for keys at or before it.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The names of the logical axes used to shard the
        ``(num_attention_heads, num_buckets)`` embedding table with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type the attention bias is computed and returned in. The embedding table
        parameter itself is always allocated in float32 and cast to this dtype before use.
    """

    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings. If False, every key located after the query falls into
            bucket 0, i.e. it is treated like distance 0.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Bucket indices are computed with NumPy on static sequence lengths, so they are
        # compile-time constants under jit.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Keys after the query (negative_rp < 0) go into the upper half of the buckets.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Unidirectional: clamp all "future" keys to distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Distances below rpb_max_exact each get an exact bucket; larger ones are binned
        # logarithmically up to max_distance and clipped to the last bucket.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            # eps keeps log() finite for distance 0; those entries are masked by rpb_is_small.
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps)
            / np.log(self.max_distance / rpb_max_exact)
            * (rpb_num_buckets - rpb_max_exact)
        ).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = nn_partitioning.param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_attention_heads, self.num_buckets),
            jnp.float32,
            axes=self.embedding_axes,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # One-hot encode the buckets: (num_buckets, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        # (heads, num_buckets) x (num_buckets, q_seqlen, k_seqlen) -> (heads, q_seqlen, k_seqlen)
        values = lax.dot_general(
            relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ()))
        )
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer.

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer. In this mode TransformerLayer adds an
        encoder-decoder (cross) attention block and requires the encoder output.
    """

    ENCODER = "encoder"
    DECODER = "decoder"


1422
class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op applied to the output of each attention
        block and of the MLP (after FC2), before the residual addition.
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout.
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The key of the RNG collection, given via flax.linen.Module.apply, used for
        generating every dropout mask in this layer (attention, hidden, intermediate
        and drop-path dropout).
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
        where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
        :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when low rank adaptation
        is enabled through :attr:`low_rank_adaptation_scope`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q K^T}{\sqrt{head\_dim}}`,
        else :math:`Q K^T`
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_relative_embedding: bool = True
    relative_embedding: Optional[nn.Module] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: Optional[float] = None
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        # Resolve defaults that depend on other fields or need a fresh initializer object.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal"
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        encoded: Array = None,
        attention_mask: Array = None,
        encoder_decoder_mask: Array = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,
    ):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor. Must be rank 3, laid out as described by
            :attr:`transpose_batch_sequence`.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
            :attr:`True` means mask out the corresponding values.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length: Optional[int], default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER`,
            :attr:`enable_relative_embedding=True` and :attr:`decode=True`.
            When not given, the sequence length of :attr:`inputs` is used.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """
        assert (
            self.layer_type in TransformerLayerType
        ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, (
            "hidden_size should be multiples of num_attention_heads"
            f", but got {self.hidden_size=} and {self.num_attention_heads=}."
        )

        assert self.layer_type == TransformerLayerType.DECODER or (
            self.layer_type == TransformerLayerType.ENCODER and decode is False
        ), "decode should be False when layer_type == TransformerLayerType.ENCODER."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            # Logical sharding axes for the (batch, seqlen) / (seqlen, batch) leading dims.
            axes = [None, None]

            is_shared_seq = (
                self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq
            )

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_attention_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Make name be the exactly same as T5X, since names would affect
        # RNGKey during init and apply. Maybe not needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = "attention"
        else:
            mha_name = "self_attention"

        inputs = with_sharding_constraint_by_logical_axes(
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name,
        )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)

        def hidden_dropout(x, deterministic):
            # Dropout on a block output before the residual add.
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(
                rate=self.hidden_dropout,
                broadcast_dims=self.hidden_dropout_dims,
                rng_collection=self.dropout_rng_name,
            )(x, deterministic=deterministic)

        def apply_drop_path(x, deterministic):
            # Stochastic depth on a residual branch. Both branches must draw from the same
            # RNG collection as every other dropout in this layer (self.dropout_rng_name).
            if self.drop_path > 0.0:
                drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
                x = nn.Dropout(
                    rate=self.drop_path,
                    broadcast_dims=drop_path_shape,
                    rng_collection=self.dropout_rng_name,
                )(x, deterministic=deterministic)
            return x

        x = with_sharding_constraint_by_logical_axes(
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        x = hidden_dropout(x, deterministic)
        x = apply_drop_path(x, deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert (
                encoded is not None
            ), "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = with_sharding_constraint_by_logical_axes(
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,  # Must do LayerNorm before MHA.
                attn_mask_type="padding",
                attn_bias_type="no_bias",
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name="encoder_decoder_attention",
            )(x, encoded, encoder_decoder_mask, deterministic=deterministic)

            y = with_sharding_constraint_by_logical_axes(
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            residual = with_sharding_constraint_by_logical_axes(
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = with_sharding_constraint_by_logical_axes(
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            name="mlp",
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = with_sharding_constraint_by_logical_axes(
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
        )

        z = hidden_dropout(z, deterministic)
        z = apply_drop_path(z, deterministic)
        z = z + residual

        if self.output_layernorm:
            z = with_sharding_constraint_by_logical_axes(
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)
            )
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                scale_axes=(W_NO_SHARD_AXES,),
                bias_axes=(W_NO_SHARD_AXES,),
                transpose_batch_sequence=self.transpose_batch_sequence,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)

        return z