# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
18
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
19
from flax.linen.attention import combine_masks
20
21
22
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
23
from jax.ad_checkpoint import checkpoint_name
24
25
26

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
27
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
28
from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
29
from ..fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
30
from ..softmax import SoftmaxType
31
32
33
34
35
36
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

# Type aliases used in the annotations of this module.
# PRNG key handed to random ops; deliberately left untyped.
PRNGKey = Any
# Tuple of dimension sizes.
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Accepted spellings of a precision setting: nothing, a name, a lax.Precision,
# or a pair of either.
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
# Callable taking (key, shape, dtype) and returning an array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Sequence of (logical axis name, mesh axis name or None) pairs.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to Figure 3 in
        `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules: every given rule in its original order,
        followed by the TE rules whose axis name is not already present in ``rules``.

    Raises
    ------
    AssertionError
        If a rule is not a 2-item ``(str, str or None)`` entry, or if an axis name that
        TE also defines is mapped differently (or more than once) in ``rules``.
    """
    # Collect every mesh axis the caller mapped each logical axis to, so that
    # duplicated or conflicting entries can be detected below.
    given_rules = {}
    for rule in rules:
        assert len(rule) == 2, \
            "The logical axis rule should be like (axis_name, mesh_axis_name)."
        axis_name, mesh_axis_name = rule
        assert isinstance(axis_name, str), \
            f"The axis_name should be str, but got {type(axis_name)}."
        assert isinstance(mesh_axis_name, str) or (mesh_axis_name is None), \
            f"The mesh_axis_name should be str or None, but got {type(mesh_axis_name)}."
        given_rules.setdefault(axis_name, []).append(mesh_axis_name)

    extended_rules = [*rules]
    for te_rule in get_sharding_map_logic_axis_to_mesh_axis().items():
        axis_name, mesh_axis_name = te_rule
        given = given_rules.get(axis_name)
        if given is None:
            # Axis is unknown to the caller's rules: adopt TE's mapping.
            extended_rules.append(te_rule)
            continue
        # Axis is already mapped by the caller: it must agree with TE exactly.
        assert len(given) == 1 and given[0] == mesh_axis_name, \
            f"The rule diverged between TE and given rule. " \
            f"Axis: {axis_name} maps to {given} in the given" \
            f" rules, but {mesh_axis_name} in TE's rules."
    return tuple(extended_rules)


112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class _UnfusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    """
    Dot-product attention written with JAX native operations (einsum + Softmax module).

    Handles both MHA and GQA (query heads being a multiple of key/value heads) and both
    the (seqlen, batch, heads, head_dim) and (batch, seqlen, heads, head_dim) layouts,
    selected by :attr:`transpose_batch_sequence`.
    """
    # Probability of dropping an attention weight after the softmax.
    attention_dropout: float = 0.
    # Mask type, mapped to a SoftmaxType inside __call__.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    # Decides whether the bias is added before or after scaling the logits.
    attn_bias_type: Optional[AttnBiasType] = None
    # dtype the attention weights are cast to after the softmax.
    dtype: DType = jnp.float32
    # When True, query and key are cast to float32 before computing the logits.
    float32_logits: bool = False
    # Logit scale; None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # True: inputs are (seqlen, batch, ...); False: inputs are (batch, seqlen, ...).
    transpose_batch_sequence: bool = True

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 dropout_rng: Optional[PRNGKey] = None,
                 deterministic: bool = False) -> Array:
        """
        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Tensors of equal rank laid out as [seqlen, batch, heads, head_dim] when
            :attr:`transpose_batch_sequence` is True, otherwise
            [batch, seqlen, heads, head_dim].
        mask: jax.numpy.ndarray, default = None
            Mask forwarded to the Softmax module. Ignored for the no-mask and causal
            mask types.
        bias: jax.numpy.ndarray, default = None
            Bias added to the attention logits.
        dropout_rng: PRNGKey, default = None
            Key used to sample the dropout mask; needed when dropout is active.
        deterministic: bool, default = False
            Disable dropout if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Attention output in the same layout as ``query``.
        """
        assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
            'q, k, v batch dims must match.')
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
        assert key.shape[-2] == value.shape[-2], 'k, v num_attention_heads must match.'
        assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'

        # Resolve the scale into a local. The module attribute itself is left untouched
        # (no `del self.scale_factor`) so the same instance can be called again.
        scale_factor = self.scale_factor
        if scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = (h_q != h_kv)

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query heads into (kv head, group member).
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
            else:
                attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
            else:
                attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

        attn_weights = checkpoint_name(attn_weights, 'logits')

        if is_gqa:
            # Fold the group axis into the head axis for the softmax; the grouped shape is
            # remembered so it can be restored afterwards.
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))

        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not be fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            # mask is ignored for no_mask and causal_mask
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                mask = None
            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
                if mask is not None:
                    return SoftmaxType.SCALED_MASKED, mask
                return SoftmaxType.SCALED, mask
            raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                             "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}")

        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)

        # NOTE(review): for PRE_SCALE_BIAS the bias has already been added to attn_weights
        # above and is also handed to Softmax here -- confirm Softmax does not add it again.
        attn_weights = Softmax(softmax_type=softmax_type,
                               scale_factor=fused_scale_factor)(attn_weights, mask,
                                                                bias).astype(self.dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Rescale the kept weights by 1 / keep_prob.
            multiplier = (keep.astype(attn_weights.dtype) /
                          jnp.asarray(keep_prob, dtype=self.dtype))
            attn_weights = attn_weights * multiplier

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
            return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)

        if is_gqa:
            return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
        return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)


class _FusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    """
    Dot-product attention that dispatches to the fused attention functions
    (``fused_attn_qkvpacked``, ``fused_attn_kvpacked`` or ``fused_attn``) according to
    :attr:`qkv_layout`.
    """
    # Dropout probability forwarded to the fused attention call.
    attention_dropout: float = 0.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    # Selects which of the three fused entry points is used and how q/k/v are read.
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    # Logit scale; None means 1 / sqrt(head_dim).
    scale_factor: Optional[float] = None
    # True: inputs/outputs are (seqlen, batch, ...) and are transposed around the fused call.
    transpose_batch_sequence: bool = False

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 dropout_rng: Optional[PRNGKey] = None,
                 deterministic: bool = False) -> Array:
        """
        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Interpreted according to :attr:`qkv_layout`; see the per-layout comments below.
        mask: jax.numpy.ndarray, default = None
            Mask forwarded to the fused attention call.
        bias: jax.numpy.ndarray, default = None
            Bias forwarded to the fused attention call.
        dropout_rng: PRNGKey, default = None
            When given, it is split into one key per device and passed as the seed.
        deterministic: bool, default = False
            The fused call is run with ``is_training=not deterministic``.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Attention output; transposed back to (seqlen, batch, ...) when
            :attr:`transpose_batch_sequence` is True.

        Raises
        ------
        ValueError
            If :attr:`qkv_layout` is not one of the supported layouts.
        """
        seed = None
        if dropout_rng is not None:
            seed = jax.random.split(dropout_rng, num_of_devices())

        # Resolve the scale into a local. The module attribute itself is left untouched
        # (no `del self.scale_factor`) so the same instance can be called again.
        scale_factor = self.scale_factor
        if scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])

        if self.qkv_layout == QKVLayout.BS3HD:
            # qkvpacked format, treat
            #   query: qkvpacked tensor, shape = [..., 3, h, d]
            #   key: ignore
            #   value: ignore
            qkv_packed = query
            if self.transpose_batch_sequence:
                qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
            x = fused_attn_qkvpacked(qkv_packed,
                                     bias,
                                     mask,
                                     seed,
                                     attn_mask_type=self.attn_mask_type,
                                     attn_bias_type=self.attn_bias_type,
                                     scaling_factor=scale_factor,
                                     dropout_probability=self.attention_dropout,
                                     is_training=not deterministic)
        elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
            # kvpacked format, treat
            #   query: query tensor, shape = [..., h, d]
            #   key: kvpacked tensor, shape = [..., 2, h, d]
            #   value: ignore
            kv_packed = key
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
            x = fused_attn_kvpacked(query,
                                    kv_packed,
                                    bias,
                                    mask,
                                    seed,
                                    attn_mask_type=self.attn_mask_type,
                                    attn_bias_type=self.attn_bias_type,
                                    scaling_factor=scale_factor,
                                    dropout_probability=self.attention_dropout,
                                    is_training=not deterministic)
        elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            # Separate query, key and value tensors.
            if self.transpose_batch_sequence:
                query = query.transpose([1, 0, 2, 3])
                key = key.transpose([1, 0, 2, 3])
                value = value.transpose([1, 0, 2, 3])
            x = fused_attn(query,
                           key,
                           value,
                           bias,
                           mask,
                           seed,
                           attn_mask_type=self.attn_mask_type,
                           attn_bias_type=self.attn_bias_type,
                           scaling_factor=scale_factor,
                           dropout_probability=self.attention_dropout,
                           is_training=not deterministic)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        # Swap batch and sequence back to match the caller's layout.
        if self.transpose_batch_sequence:
            x = x.transpose([1, 0, 2, 3])

        return x


class DotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for models like T5X, which don't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """
    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.
    # The three fields below take user-facing strings; they are converted to the
    # internal enums inside __call__.
    attn_mask_type: str = 'causal'
    attn_bias_type: Optional[str] = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = 'dropout'
    float32_logits: bool = False
    qkv_layout: str = 'bshd_bshd_bshd'
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 deterministic: bool = False) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """

        # For internal API, we use enums. The converted values live in locals only; the
        # module attributes are not deleted, so the same instance can be called again.
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]

        # A bias must be supplied exactly when the bias type asks for one.
        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # `None` means "same as num_attention_heads"; resolve it before the kernel check.
        num_gqa_groups = self.num_gqa_groups
        if num_gqa_groups is None:
            num_gqa_groups = self.num_attention_heads

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout == QKVLayout.BS3HD:
            # Packed qkv: key is ignored, so kv length equals q length.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
                                                               attn_bias_type, attn_mask_type,
                                                               self.attention_dropout,
                                                               self.num_attention_heads,
                                                               num_gqa_groups, seqlen_q,
                                                               seqlen_kv, self.head_dim)

        use_fused_attn = (enable_fused_attn and has_fused_attn_kernel)

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn("Fused attention is not enabled because there is no available kernel.\n"
                          "Fall back to the unfused attention.\n"
                          "Please try to update the cuDNN and TE to the latest version.\n"
                          f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                          f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                          f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n")

        # Only draw a dropout key when dropout is actually going to be applied.
        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        scale_factor = self.scale_factor
        if scale_factor is None:
            scale_factor = 1.0 / sqrt(self.head_dim)

        if not use_fused_attn:
            # unfused attention only supports split query, key, value
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(functools.partial(jnp.squeeze, axis=-3),
                                        [query, key, value])
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            x = _UnfusedDotProductAttention(attention_dropout=self.attention_dropout,
                                            attn_mask_type=attn_mask_type,
                                            attn_bias_type=attn_bias_type,
                                            dtype=self.dtype,
                                            float32_logits=self.float32_logits,
                                            scale_factor=scale_factor,
                                            transpose_batch_sequence=self.transpose_batch_sequence)(
                                                query,
                                                key,
                                                value,
                                                mask,
                                                bias,
                                                dropout_rng=dropout_rng,
                                                deterministic=deterministic)
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

        return x
562
563


564
565
566
567
def rotary_pos_emb(x: Array,
                   windows: Tuple[int, int],
                   transpose_batch_sequence: bool,
                   group_method: str = 'consecutive'):
    """
    Rotary Positional Embedding

    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor; the last axis is rotated pair-wise, using the index along the
        sequence axis as the position.
    windows: Tuple[int, int]
        ``(min_window, max_window)``. The time scale of pair ``i`` is
        ``min_window * (max_window / min_window) ** (2 * i / hidden)``.
    transpose_batch_sequence: bool
        Whether the sequence axis comes first (True) or second (False).
    group_method: str, default = 'consecutive'
        How elements of the last axis are paired. Matching ignores case, surrounding
        whitespace, '-' and '_'.

        * 'consecutive': element ``2i`` is paired with element ``2i + 1``.
        * 'alternate': element ``i`` is paired with element ``i + hidden // 2``.

    Returns
    -------
    output: jax.numpy.ndarray
        Tensor with the same shape and dtype as ``x``.

    Raises
    ------
    AssertionError
        If ``group_method`` is neither 'consecutive' nor 'alternate'.
    """
    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2
    min_window, max_window = windows[0], windows[1]

    # One time scale per rotated pair, broadcastable against x on the last axis.
    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window)**fraction
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Position indices placed on the sequence axis, size 1 on every other axis.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        # Pairs are (first half, second half) of the last axis.
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1)
        return output

    def consecutive_impl():
        # Pairs are neighbouring elements, so every time scale is used twice in a row.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        # For every element pick its pair partner: odd positions take the left
        # neighbour (right-shifted x), even positions take the right neighbour.
        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        x_shifted = jax.lax.select(
            jnp.tile(
                jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
                x.shape[:-1] + (1,),
            ),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 on even positions, +1 on odd positions.
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace('-', '').replace('_', '')
        assert canonicalized_gm in ['consecutive', 'alternate'], \
            f"Invalid relative positional embedding group method. " \
            f"Expect to be in ['alternate' or 'consecutive'], but got {gm}."

        return canonicalized_gm

    group_method = canonicalize_group_method(group_method)

    if group_method == 'alternate':
        return alternate_impl()
    return consecutive_impl()
638
639


640
class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing the QKV and output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    use_bias: bool, default = False
        Indicate whether or not to enable bias shifting for QKV and output projections.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    fuse_qkv_params: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}} * K`,
        else :math:`Q*K`
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.
    dropout_rng_name: str = 'dropout'
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    return_layernorm_output: bool = False
    zero_centered_gamma: bool = False
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    attn_mask_type: str = 'causal'
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    dtype: DType = jnp.float32
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        """
        Resolve deprecated fields and fill in defaults that depend on other fields.

        Every deprecated field that has a direct successor is forwarded to that successor
        (with a DeprecationWarning), so that callers still using the old names keep the
        behaviour they asked for. `output_layernorm` has no successor and is rejected.
        """
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
        if self.apply_residual_connection_post_layernorm is not None:
            # Honour the deprecated value instead of silently dropping it.
            self.return_layernorm_output = self.apply_residual_connection_post_layernorm
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning)
        if self.fuse_qkv is not None:
            # Honour the deprecated value instead of silently dropping it.
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs_q: Array,
                 inputs_kv: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> Tuple[Array, Optional[Array]]:
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q: jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv: jax.numpy.ndarray
            Input tensor for key/value projection.
            Self-attention is detected by object identity (``inputs_q is inputs_kv``).
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
        *
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache.
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        ln_out: Optional[jax.numpy.ndarray]
            The second value returned by the query/QKV `LayerNormDenseGeneral` projection,
            i.e. its layer normalization output (presumably `None` when it was not requested).
        """

        def query_init(*args):
            # Optionally fold the 1/sqrt(head_dim) scaling into the initial query weights.
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            # Initialize Q, K and V independently, then stack them into the fused kernel.
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            # Initialize K and V independently, then stack them into the fused kernel.
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            # Logical axes of the two leading dims, ordered according to the input layout.
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

        is_self_attn = (inputs_q is inputs_kv)
        is_gqa = (self.num_attention_heads != self.num_gqa_groups)
        # A single packed QKV projection is only possible when Q, K and V share
        # both the input and the number of heads.
        is_qkvpack = (is_self_attn and not is_gqa)

        inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
            self.enable_sequence_parallel), HIDDEN_AXES)
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

        if self.fuse_qkv_params:
            if is_qkvpack:
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_attention_heads * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.return_layernorm_output,
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name='qkv',
                    dtype=self.dtype)(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
                qkv_layout = QKVLayout.BS3HD
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_attention_heads * self.head_dim,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_TP_AXES,),
                    dtype=self.dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name='query')(inputs_q)

                if is_self_attn:
                    # For self-attention the K/V projection consumes the layernorm output.
                    assert ln_out is not None
                    inputs_kv = ln_out

                kv_proj = DenseGeneral(axis=-1,
                                       features=(2, self.num_gqa_groups * self.head_dim),
                                       transpose_batch_sequence=self.transpose_batch_sequence,
                                       kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                                       kernel_init=kv_init,
                                       use_bias=self.use_bias,
                                       bias_init=self.bias_init,
                                       bias_axes=(W_JOINED_AXES, W_TP_AXES),
                                       name='kv',
                                       dtype=self.dtype)(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
                qkv_layout = QKVLayout.BSHD_BS2HD
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
                features=self.num_gqa_groups * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                dtype=self.dtype)
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=self.input_layernorm,
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_attention_heads * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                dtype=self.dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
                name='query')(inputs_q)

            if is_self_attn:
                # For self-attention the K/V projections consume the layernorm output.
                assert ln_out is not None
                inputs_kv = ln_out

            key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
            query = checkpoint_name(query, 'query_proj')
            key = checkpoint_name(key, 'key_proj')
            value = checkpoint_name(value, 'value_proj')
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if self.enable_rotary_pos_emb:
            # The rotary embedding works on separate tensors, so unpack any fused projection.
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

            query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
                                   self.transpose_batch_sequence, self.rotary_pos_emb_group_method)
            key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence,
                                 self.rotary_pos_emb_group_method)
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if decode:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            is_initialized = self.has_variable('cache', 'cached_key')

            cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
                                         value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                # The cached key carries `num_gqa_groups` heads while the query carries
                # `num_attention_heads`, so the expected query shape must not reuse the
                # head count of the cache (they differ under GQA).
                if self.transpose_batch_sequence:
                    length, batch, _, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, _, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        f"expected query shape {expected_shape} instead got {query.shape}.")

                # Write the new key/value into the cache slot selected by a one-hot mask
                # over the sequence axis, then advance the cache index.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                # Mask out the cache positions that have not been filled yet.
                mask = combine_masks(
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))

                if bias is not None:
                    # Keep only the bias row that belongs to the current decoding position.
                    dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim,
                                                       in_axes=(None, 0, None, None))
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                       jnp.reshape(cur_index, (-1)), 1, -2)

        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

        LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
            qkv_proj = qkv_proj.reshape(*qkv_proj.shape[:2], 3, self.num_attention_heads,
                                        self.head_dim)
            qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
        else:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

        x = DotProductAttention(head_dim=self.head_dim,
                                num_attention_heads=self.num_attention_heads,
                                num_gqa_groups=self.num_gqa_groups,
                                attn_mask_type=self.attn_mask_type,
                                attn_bias_type=self.attn_bias_type,
                                attention_dropout=self.attention_dropout,
                                dtype=self.dtype,
                                dropout_rng_name=self.dropout_rng_name,
                                float32_logits=self.float32_logits,
                                qkv_layout=qkv_layout.name,
                                scale_factor=scale_factor,
                                transpose_batch_sequence=self.transpose_batch_sequence)(
                                    *dpa_args, mask, bias, deterministic=deterministic)
        # Merge the head and per-head hidden dims back into a single hidden dim.
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)

        out = DenseGeneral(features=inputs_q.shape[-1],
                           transpose_batch_sequence=self.transpose_batch_sequence,
                           axis=-1,
                           kernel_init=self.kernel_init,
                           kernel_axes=(W_TP_AXES, W_FSDP_AXES),
                           use_bias=self.use_bias,
                           bias_init=self.bias_init,
                           bias_axes=(W_NO_SHARD_AXES,),
                           dtype=self.dtype,
                           name='out')(x)
        out = checkpoint_name(out, 'out_proj')

        return out, ln_out
1150
1151


1152
class RelativePositionBiases(nn.Module):    # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type the embedding table is cast to before the bias is computed.
        The table itself is always allocated in float32.
    """
    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets')
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # The bucket indices only depend on static sequence lengths, so they are
        # computed with NumPy rather than traced JAX ops.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Half of the buckets are used for keys located after the query,
            # the other half for keys located at or before the query.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Keys located after the query all collapse onto distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Distances below `rpb_max_exact` get one bucket each; larger distances are
        # spread logarithmically over the remaining buckets up to `max_distance`.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) /
            np.log(self.max_distance / rpb_max_exact) *
            (rpb_num_buckets - rpb_max_exact)).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = nn_partitioning.param_with_axes(
            'rel_embedding',
            self.embedding_init, (self.num_attention_heads, self.num_buckets),
            jnp.float32,
            axes=self.embedding_axes)

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # Look the per-head bias up through a one-hot contraction over the bucket axis.
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot,
                                 (((1,), (0,)), ((), ())))
        # Add the leading broadcast dimension.
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer.

    Values
    ----------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """
    ENCODER = "encoder"
    DECODER = "decoder"


1259
class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op applied to the output of each
        attention block and to the output of the MLP (after FC2 layer).
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for
        generating Dropout masks (attention, hidden, intermediate and drop-path dropout).
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. Default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation in the self attention.
        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}

        Each described below:

        * no_mask: No attention mask is applied. This means the self attention will consider the
          full sequence without any restrictions.
        * padding: Indicates the presence of padding at the end of each sequence.
          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
          :attr:`__call__` method to specify the padding positions.
        * causal: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
        * causal_padding / padding_causal: A combination of both causal and padding masks.
          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.

        .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.

    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head_dim}}*K`,
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = 'dropout'
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = 'causal'
    self_attn_bias_type: Optional[str] = None
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        # Resolve `None` defaults before handing over to the base-class
        # __post_init__, so the resolved values are what the module sees.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
                                                                    'truncated_normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs: Array,
                 encoded: Array = None,
                 attention_mask: Array = None,
                 encoder_decoder_mask: Array = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: Optional[int] = None):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor. Must be 3-dimensional.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
            :attr:`True` means mask out the corresponding values.
            Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
            :attr:`True` means mask out the corresponding values.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length: int, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`. Only used when :attr:`decode=True`;
            otherwise the sequence length of :attr:`inputs` is used.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """
        assert self.layer_type in TransformerLayerType, \
                "layer_type should be one of TransformerLayerType" \
                f", but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, \
                "hidden_size should be multiples of num_attention_heads" \
                f", but got {self.hidden_size=} and {self.num_attention_heads=}."

        assert self.layer_type == TransformerLayerType.DECODER or \
              (self.layer_type == TransformerLayerType.ENCODER and decode is False), \
               "decode should be False when layer_type == TransformerLayerType.ENCODER."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            # Build the logical axis names of the (batch, seqlen) pair in the
            # layout selected by `transpose_batch_sequence`. `is_shared_seq`
            # overrides `enable_sequence_parallel` when given explicitly.
            axes = [None, None]

            is_shared_seq = self.enable_sequence_parallel if is_shared_seq is None \
                            else is_shared_seq

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(num_buckets=32,
                                                 max_distance=128,
                                                 num_attention_heads=self.num_attention_heads,
                                                 dtype=self.dtype,
                                                 embedding_init=nn.initializers.variance_scaling(
                                                     1.0, 'fan_avg', 'uniform'),
                                                 name='relpos_bias')
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                # Encoder: bidirectional relative position buckets.
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                # Decoder: unidirectional buckets; while decoding, size the
                # bias by the maximum decode length when one is provided.
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Make the name exactly the same as T5X, since names would affect
        # RNGKey during init and apply. Maybe not needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = 'attention'
        else:
            mha_name = 'self_attention'

        inputs = with_sharding_constraint_by_logical_axes(
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name)(inputs,
                           inputs,
                           attention_mask,
                           attn_bias,
                           deterministic=deterministic,
                           decode=decode)

        def hidden_dropout(x, deterministic):
            # Dropout on a block output, sharing the mask along
            # `hidden_dropout_dims`.
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(rate=self.hidden_dropout,
                              broadcast_dims=self.hidden_dropout_dims,
                              rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

        x = with_sharding_constraint_by_logical_axes(
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape,
                           rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert encoded is not None, \
                "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = with_sharding_constraint_by_logical_axes(
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,    # Must do LayerNorm before MHA.
                attn_mask_type='padding',
                attn_bias_type='no_bias',
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name='encoder_decoder_attention')(x,
                                                  encoded,
                                                  encoder_decoder_mask,
                                                  deterministic=deterministic)

            y = with_sharding_constraint_by_logical_axes(
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
            residual = with_sharding_constraint_by_logical_axes(
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = with_sharding_constraint_by_logical_axes(
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            name='mlp',
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = with_sharding_constraint_by_logical_axes(
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            # Use the configured RNG collection, consistent with the drop-path
            # after self-attention; otherwise a custom `dropout_rng_name`
            # would not be honoured here.
            z = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape,
                           rng_collection=self.dropout_rng_name)(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            z = with_sharding_constraint_by_logical_axes(
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
            z = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          epsilon=self.layernorm_epsilon,
                          scale_axes=(W_NO_SHARD_AXES,),
                          bias_axes=(W_NO_SHARD_AXES,),
                          transpose_batch_sequence=self.transpose_batch_sequence,
                          dtype=self.dtype,
                          name="output_layernorm")(z)

        return z