# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
18
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
19
from flax.linen.attention import combine_masks
20
21
22
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
23
from jax.ad_checkpoint import checkpoint_name
24
25
26

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
27
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
28
29
from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn
30
from ..softmax import SoftmaxType
31
32
33
34
35
36
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

# Type aliases used in the annotations of this module.
PRNGKey = Any    # JAX PRNG key (e.g. the result of Module.make_rng); kept untyped here.
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Precision argument forms: None, a name, a lax.Precision, or a pair of names/Precisions.
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
# Parameter initializer signature: (rng, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical axis rules: a sequence of (logical_axis_name, mesh_axis_name or None) pairs.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to Figure 3 in `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules: Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules: Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules, as a tuple. The given rules come first, in their
        original order, followed by every TE rule whose logical axis name does not already
        appear in ``rules``.

    Raises
    ------
    AssertionError
        If a rule is not a pair of (str, str or None), or if ``rules`` maps a logical axis
        that TE also maps to a different mesh axis (or to more than one mesh axis).
    """
    # Group the caller's mesh-axis choices by logical axis name so that duplicates and
    # conflicts with TE's own mapping can be detected below.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, \
            "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item[0], item[1]
        assert isinstance(key, str), \
            f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (val is None), \
            f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    extended_rules = [*rules]
    for key, val in get_sharding_map_logic_axis_to_mesh_axis().items():
        if key in rules_map:
            # A user rule for an axis that TE also maps must agree exactly with TE's rule.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \
                "The rule diverged between TE and given rule. " \
                f"Axis:{key} map to {rules_map[key]} in the given" \
                f" rules, but {val} in TE's rules."
        else:
            extended_rules.append((key, val))
    return tuple(extended_rules)


112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class _UnfusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    """Pure-JAX (unfused) dot-product attention backend used by :class:`DotProductAttention`.

    Computes attention logits from ``query``/``key`` with einsums, applies the optional bias,
    the mask/softmax (via :class:`Softmax`) and dropout, then contracts with ``value``.
    Grouped-query attention (GQA) is used automatically when ``query`` has more heads than
    ``key``/``value``; each run of ``h_q // h_kv`` consecutive query heads shares one
    key/value head.
    """
    attention_dropout: float = 0.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    float32_logits: bool = False
    scale_factor: Optional[float] = None    # None -> 1 / sqrt(head_dim of query)
    transpose_batch_sequence: bool = True    # True: (seqlen, batch, heads, head_dim) inputs

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 dropout_rng: Optional[PRNGKey] = None,
                 deterministic: bool = False) -> Array:
        """Apply unfused attention.

        Parameters
        ----------
        query, key, value: jax.numpy.ndarray
            Rank-4 tensors laid out as (seqlen, batch, heads, head_dim) when
            ``transpose_batch_sequence`` is True, otherwise (batch, seqlen, heads, head_dim).
            ``query`` may have a multiple of the key/value head count (GQA).
        mask: jax.numpy.ndarray, default = None
            Attention mask forwarded to the :class:`Softmax` module.
        bias: jax.numpy.ndarray, default = None
            Attention bias; where it is applied depends on ``attn_bias_type``.
        dropout_rng: PRNGKey, default = None
            Key for attention dropout; needed when dropout is active.
        deterministic: bool, default = False
            Disable attention dropout when True.

        Returns
        -------
        jax.numpy.ndarray
            Attention output in the same (seqlen/batch) layout as the inputs.
        """
        assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
            'q, k, v batch dims must match.')
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
        assert key.shape[-2] == value.shape[-2], 'k, v num_attention_heads must match.'
        assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
            scale_factor = self.scale_factor
        # Remove the attribute so that later code can only use the resolved local value.
        del self.scale_factor

        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)
        h_q, h_kv = query.shape[-2], key.shape[-2]
        # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
        # Therefore, we have to maintain two code paths.
        is_gqa = (h_q != h_kv)

        if is_gqa:
            assert (h_q % h_kv == 0) and (h_q >= h_kv)
            group_size = h_q // h_kv
            # Split the query head axis into (kv_head, group): consecutive query heads share
            # one key/value head.
            grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            if is_gqa:
                attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
            else:
                attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
        else:
            if is_gqa:
                attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
            else:
                attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

        attn_weights = checkpoint_name(attn_weights, 'logits')

        if is_gqa:
            # Fold the group axis into the head axis so the sharding constraint and the
            # softmax below see a 4-D (batch, heads, q_len, kv_len) tensor.
            b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
            attn_weights_without_groups_shape = (b, h * g, q, k)
            attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))

        # The bias handed to the Softmax module, which adds it to its input.
        softmax_bias = bias
        # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
        # In this case, the scale can not fused into the Softmax module.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
            attn_weights = attn_weights * scale_factor
            fused_scale_factor = 1.
        else:
            # If not post_scale_bias, the scale can be fused into Softmax module
            fused_scale_factor = scale_factor
            if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
                attn_weights += bias
                # The bias is already part of the logits; passing it to Softmax as well
                # would apply it a second time.
                softmax_bias = None

        def convert_to_softmax_type(attn_mask_type, mask):
            """Convert the attn_mask_type to SoftmaxType"""
            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
            if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
                if mask is not None:
                    return SoftmaxType.SCALED_MASKED
                return SoftmaxType.SCALED
            raise ValueError(f"Unsupported {attn_mask_type=}, "
                             "supported attn_mask_type = {'causal', 'padding'}")

        softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)

        attn_weights = Softmax(softmax_type=softmax_type,
                               scale_factor=fused_scale_factor)(attn_weights, mask,
                                                                softmax_bias).astype(self.dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        if not deterministic and self.attention_dropout > 0.:
            keep_prob = 1.0 - self.attention_dropout
            dropout_shape = list(attn_weights.shape)
            # TODO(rewang): add attention dropout broadcast dimension arguments for users
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            # Inverted dropout: kept weights are rescaled by 1 / keep_prob.
            multiplier = (keep.astype(attn_weights.dtype) /
                          jnp.asarray(keep_prob, dtype=self.dtype))
            attn_weights = attn_weights * multiplier

        if self.transpose_batch_sequence:
            if is_gqa:
                return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
            return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)

        if is_gqa:
            return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
        return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)


class _FusedDotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    """cuDNN fused-attention backend used by :class:`DotProductAttention`.

    Dispatches on ``qkv_layout`` to ``self_fused_attn`` (packed QKV), ``cross_fused_attn``
    (separate Q, packed KV) or ``fused_attn`` (separate Q, K and V). When
    ``transpose_batch_sequence`` is True, the leading two axes of the inputs are swapped
    before the call and the output's leading two axes are swapped back afterwards.
    """
    attention_dropout: float = 0.
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 dropout_rng: Optional[PRNGKey] = None,
                 deterministic: bool = False) -> Array:
        # One key per device (split from dropout_rng), passed as the kernels' seed argument.
        seed = None if dropout_rng is None else jax.random.split(dropout_rng, num_of_devices())

        scale_factor = (1.0 / sqrt(query.shape[-1])
                        if self.scale_factor is None else self.scale_factor)
        del self.scale_factor

        # Keyword arguments shared by all three fused-attention entry points.
        kernel_kwargs = dict(attn_mask_type=self.attn_mask_type,
                             attn_bias_type=self.attn_bias_type,
                             scaling_factor=scale_factor,
                             dropout_probability=self.attention_dropout,
                             is_training=not deterministic)

        swap_leading_4d = [1, 0, 2, 3]
        swap_leading_5d = [1, 0, 2, 3, 4]
        swap_layout = self.transpose_batch_sequence

        if self.qkv_layout == QKVLayout.BS3HD:
            # qkvpacked format: `query` holds the packed tensor [..., 3, h, d];
            # `key` and `value` are ignored.
            qkv_packed = query.transpose(swap_leading_5d) if swap_layout else query
            out = self_fused_attn(qkv_packed, bias, mask, seed, **kernel_kwargs)
        elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
            # kvpacked format: `query` is [..., h, d], `key` holds the packed tensor
            # [..., 2, h, d]; `value` is ignored.
            kv_packed = key
            if swap_layout:
                query = query.transpose(swap_leading_4d)
                kv_packed = kv_packed.transpose(swap_leading_5d)
            out = cross_fused_attn(query, kv_packed, bias, mask, seed, **kernel_kwargs)
        elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            if swap_layout:
                query, key, value = (t.transpose(swap_leading_4d) for t in (query, key, value))
            out = fused_attn(query, key, value, bias, mask, seed, **kernel_kwargs)
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")

        return out.transpose(swap_leading_4d) if swap_layout else out


class DotProductAttention(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Dot Product Attention (DPA). Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::
        The DotProductAttention module supports two backends: the unfused and the fused attention
        mechanisms. The unfused attention is implemented using JAX native operations, providing
        broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused
        attention
        <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for
        higher performance and lower memory usage on the supported hardware.
        Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
        variable:

        * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
        * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
          kernel is not available on the system, a warning will be issued, and the module will
          automatically fall back to the unfused backend.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        Type of the attention mask passed into softmax operation in the self attention.
        Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
        Introduced in v0.10.0.
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the :attr:`bias`
        argument of :attr:`__call__()`: it is :attr:`post_scale_bias` if a bias is given,
        otherwise :attr:`no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    qkv_layout: str, default = 'bshd_bshd_bshd'
        Specifies the dimensional layout format for the query, key, and value tensors in __call__().
        It indicates how the inputs are processed.
        Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where

        * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
          key and value arguments in :attr:`__call__()` are ignored in this layout.
        * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
          tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
        * bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].

        Explanation of denotations:

        * b: batch size
        * s: sequence length
        * h: num_attention_heads or num_gqa_groups
        * d: head dimension

    scale_factor: Optional[float], default = None
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        Data type of the attention computation: it is the query/key dtype used when checking
        fused-kernel availability, and the dtype of the softmax output in the unfused backend.
    """
    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.
    attn_mask_type: str = 'causal'
    attn_bias_type: Optional[str] = None
    dtype: DType = jnp.float32
    dropout_rng_name: str = 'dropout'
    float32_logits: bool = False
    qkv_layout: str = 'bshd_bshd_bshd'
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True

    @nn.compact
    def __call__(self,
                 query: Array,
                 key: Array,
                 value: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 deterministic: bool = False) -> Array:
        """
        Parameters
        ----------
        query: jax.numpy.ndarray
            The details of query tensor representation is described in :attr:`qkv_layout`.
        key: jax.numpy.ndarray
            The details of key tensor representation is described in :attr:`qkv_layout`.
        value: jax.numpy.ndarray
            The details of value tensor representation is described in :attr:`qkv_layout`.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means to mask out the corresponding values.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift attention softmax input.
        *:
            Below parameters are keyword only
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """

        # For internal API, we use enum to maintain
        if self.attn_bias_type is None:
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
        else:
            attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
        qkv_layout = QKVLayout[self.qkv_layout.upper()]
        del self.attn_bias_type, self.attn_mask_type, self.qkv_layout

        if attn_bias_type == AttnBiasType.NO_BIAS:
            assert bias is None
        else:
            assert bias is not None

        # `num_gqa_groups=None` means plain MHA: one key/value head per query head. Resolve it
        # here so the kernel-availability check always receives an integer head count.
        num_gqa_groups = (self.num_attention_heads
                          if self.num_gqa_groups is None else self.num_gqa_groups)

        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        seqlen_q = query.shape[sequence_dim]
        if qkv_layout == QKVLayout.BS3HD:
            # Packed QKV: key/value share the query's sequence length.
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
                                                               attn_bias_type, attn_mask_type,
                                                               self.attention_dropout,
                                                               self.num_attention_heads,
                                                               num_gqa_groups, seqlen_q,
                                                               seqlen_kv, self.head_dim)

        use_fused_attn = (enable_fused_attn and has_fused_attn_kernel)

        if enable_fused_attn and not has_fused_attn_kernel:
            warnings.warn("Fused attention is not enabled because there is no available kernel.\n"
                          "Fall back to the unfused attention.\n"
                          "Please try to update the cuDNN and TE to the latest version.\n"
                          f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                          f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                          f"{num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n")

        dropout_rng = None
        if not deterministic and self.attention_dropout > 0.:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(self.head_dim)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor

        if not use_fused_attn:
            # unfused attention only supports splitted query, key, value
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(query, [1, 2], axis=-3)
                query, key, value = map(functools.partial(jnp.squeeze, axis=-3),
                                        [query, key, value])
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(key, [1], axis=-3)
                key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            x = _UnfusedDotProductAttention(attention_dropout=self.attention_dropout,
                                            attn_mask_type=attn_mask_type,
                                            attn_bias_type=attn_bias_type,
                                            dtype=self.dtype,
                                            float32_logits=self.float32_logits,
                                            scale_factor=scale_factor,
                                            transpose_batch_sequence=self.transpose_batch_sequence)(
                                                query,
                                                key,
                                                value,
                                                mask,
                                                bias,
                                                dropout_rng=dropout_rng,
                                                deterministic=deterministic)
        else:
            x = _FusedDotProductAttention(
                attention_dropout=self.attention_dropout,
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
            )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

        return x
542
543


544
545
546
547
def rotary_pos_emb(x: jnp.ndarray,
                   windows: Tuple[int, int],
                   transpose_batch_sequence: bool,
                   group_method: str = 'consecutive'):
    """
    Rotary Positional Embedding
    x should be in shape of
    [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.

    Parameters
    ----------
    x: jax.numpy.ndarray
        Input tensor in one of the layouts above. Its last (hidden) dimension is grouped
        into pairs, and each pair is rotated by a position-dependent angle.
    windows: Tuple[int, int]
        (min_window, max_window). Pair ``i`` uses the timescale
        ``min_window * (max_window / min_window) ** (2 * i / hidden)``; at sequence position
        ``p`` its rotation angle is ``p / timescale``.
    transpose_batch_sequence: bool
        Selects which of the two layouts above ``x`` uses.
    group_method: str, default = 'consecutive'
        How hidden features are paired: 'consecutive' rotates (x[2i], x[2i + 1]);
        'alternate' rotates (x[i], x[i + hidden // 2]). Case, surrounding whitespace,
        '-' and '_' are ignored.

    Returns
    -------
    jax.numpy.ndarray
        Tensor with the same shape and dtype as ``x``.

    Raises
    ------
    AssertionError
        If ``group_method`` is neither 'consecutive' nor 'alternate'.
    """
    hidden_dim = x.shape[-1]
    half_hidden_dim = hidden_dim // 2
    min_window = windows[0]
    max_window = windows[1]

    fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
    time_scales = min_window * (max_window / min_window)**fraction
    # Broadcast the per-pair timescales against every leading axis of x.
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    # Sequence positions shaped to broadcast along the sequence axis only.
    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    def generate_sin_cos(timescales):
        sinusoidal_positions = positions / timescales
        sin = jnp.sin(sinusoidal_positions)
        cos = jnp.cos(sinusoidal_positions)
        return sin, cos

    def alternate_impl():
        # Pair feature i of the first half with feature i of the second half.
        sin, cos = generate_sin_cos(time_scales)

        x1, x2 = jnp.split(x, 2, axis=-1)
        part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
        part_2 = (x2 * cos + x1 * sin).astype(x.dtype)

        output = jnp.concatenate([part_1, part_2], axis=-1)
        return output

    def consecutive_impl():
        # Pair neighbouring features (2i, 2i + 1); both members use timescale i.
        sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))

        # Odd positions take their left neighbour, even positions their right neighbour.
        is_odd = jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) == 1
        x_shifted_left = jnp.roll(x, -1, axis=-1)
        x_shifted_right = jnp.roll(x, 1, axis=-1)
        x_shifted = jax.lax.select(
            jnp.tile(is_odd, x.shape[:-1] + (1,)),
            x_shifted_right,
            x_shifted_left,
        )

        # -1 for even features, +1 for odd features.
        sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)

        output = x * cos + x_shifted * sin * sign
        output = output.astype(x.dtype)
        return output

    def canonicalize_group_method(gm):
        canonicalized_gm = gm.lower().strip().replace('-', '').replace('_', '')
        assert canonicalized_gm in ['consecutive', 'alternate'], \
            "Invalid relative positional embedding group method. " \
            f"Expect to be in ['alternate', 'consecutive'], but got {gm}."

        return canonicalized_gm

    group_method = canonicalize_group_method(group_method)

    if group_method == 'alternate':
        return alternate_impl()
    return consecutive_impl()
618
619


620
class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    head_dim: int
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
        Dropout probability for the dropout op after the softmax.
    attn_mask_type: str, default = 'causal'
        Type of the attention mask passed into softmax operation in the attention.
        Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
        Introduced in v0.10.0.
    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When the default is present, the type is automatically decided by the MHA's bias
        parameter: it is `post_scale_bias` if there is a bias, otherwise `no_bias` is used.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing the QKV and output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    use_bias: bool, default = False
        Indicate whether or not to enable bias shifting for QKV and output projections.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    input_layernorm: bool, default = True
        If set to False, layer normalization to the input is not applied.
    return_layernorm_output: bool, default = False
        If set to True, output of layernorm is returned from the forward together with the output
        of the linear transformation.
        Example use case: residual connection for transformer module is taken post layernorm.
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to pair the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
        where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    num_heads: int, default = None
        Deprecated. Please refer `num_attention_heads`.
        When set, its value overrides :attr:`num_attention_heads`.
    dropout_rate: float, default = None
        Deprecated. Please refer `attention_dropout`.
        When set, its value overrides :attr:`attention_dropout`.
    output_layernorm: bool, default = None
        Deprecated. Please refer `input_layernorm`. It must be left as None.
    apply_residual_connection_post_layernorm: bool, default = None
        Deprecated. Please refer `return_layernorm_output`.
        When set, its value overrides :attr:`return_layernorm_output`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    fuse_qkv_params: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q*K}{\sqrt{head\_dim}}`,
        else :math:`Q*K`
    scaled_query_init: bool, default = True
        Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
    float32_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    fuse_qkv: bool, default = None
        Deprecated. Please refer `fuse_qkv_params`.
        When set, its value overrides :attr:`fuse_qkv_params`.
    """

    head_dim: int
    num_attention_heads: int
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.
    dropout_rng_name: str = 'dropout'
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    return_layernorm_output: bool = False
    zero_centered_gamma: bool = False
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    attn_mask_type: str = 'causal'
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    dtype: DType = jnp.float32
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        # Deal with the deprecated parameters. Every deprecated alias that is still accepted
        # is forwarded to its replacement, so configurations written against the old API
        # keep their meaning instead of being silently ignored.
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
        if self.apply_residual_connection_post_layernorm is not None:
            self.return_layernorm_output = self.apply_residual_connection_post_layernorm
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning)
        if self.fuse_qkv is not None:
            self.fuse_qkv_params = self.fuse_qkv
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs_q: Array,
                 inputs_kv: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> Tuple[Array, Optional[Array]]:
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q: jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv: jax.numpy.ndarray
            Input tensor for key/value projection. Passing the very same array object as
            :attr:`inputs_q` selects the self-attention path.
        mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out the attention softmax input.
            :attr:`True` means mask out the corresponding values.
        bias: jax.numpy.ndarray, default = None
            A tensor used to shift the attention softmax input.
        *
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache.
        deterministic: bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensor of the output projection.
        ln_out: jax.numpy.ndarray or None
            Output of the input layernorm. Note that some paths (the unfused projections and
            fused self-attention with GQA) always request it, so it may be returned even when
            :attr:`return_layernorm_output` is False.
        """

        def query_init(*args):
            # Optionally fold the 1/sqrt(head_dim) logit scaling into the query weights.
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            # Initialize the packed (hidden, 3, heads * head_dim) kernel slice by slice so the
            # query slice gets query_init and key/value get the plain kernel_init.
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            # Same as qkv_init, for the packed (hidden, 2, groups * head_dim) key/value kernel.
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            # Logical sharding axes for the two leading (batch/sequence) dimensions, honoring
            # transpose_batch_sequence.
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

        is_self_attn = (inputs_q is inputs_kv)
        is_gqa = (self.num_attention_heads != self.num_gqa_groups)
        # Q, K and V can share one packed projection only when they come from the same input
        # and have the same number of heads.
        is_qkvpack = (is_self_attn and not is_gqa)

        # The layernorm input may be sequence-sharded; the dot input is not.
        inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
            self.enable_sequence_parallel), HIDDEN_AXES)
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

        if self.fuse_qkv_params:
            if is_qkvpack:
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_attention_heads * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.return_layernorm_output,
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name='qkv',
                    dtype=self.dtype)(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
                qkv_layout = QKVLayout.BS3HD
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=self.input_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_attention_heads * self.head_dim,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    # Self-attention needs the normalized input to feed the K/V projection.
                    return_layernorm_output=(self.return_layernorm_output or is_self_attn),
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_TP_AXES,),
                    dtype=self.dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name='query')(inputs_q)

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

                kv_proj = DenseGeneral(axis=-1,
                                       features=(2, self.num_gqa_groups * self.head_dim),
                                       transpose_batch_sequence=self.transpose_batch_sequence,
                                       kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                                       kernel_init=kv_init,
                                       use_bias=self.use_bias,
                                       bias_init=self.bias_init,
                                       bias_axes=(W_JOINED_AXES, W_TP_AXES),
                                       name='kv',
                                       dtype=self.dtype)(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
                qkv_layout = QKVLayout.BSHD_BS2HD
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
                features=self.num_gqa_groups * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                dtype=self.dtype)
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=self.input_layernorm,
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_attention_heads * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                dtype=self.dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
                name='query')(inputs_q)

            if is_self_attn:
                assert ln_out is not None
                inputs_kv = ln_out

            key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
            query = checkpoint_name(query, 'query_proj')
            key = checkpoint_name(key, 'key_proj')
            value = checkpoint_name(value, 'value_proj')
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if self.enable_rotary_pos_emb:
            # Rotary embedding is applied to query and key separately, so unpack any packed
            # projection first; afterwards the layout is always fully separate.
            if qkv_layout == QKVLayout.BS3HD:
                query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            elif qkv_layout == QKVLayout.BSHD_BS2HD:
                key, value = jnp.split(kv_proj, [1], axis=-2)
            else:
                assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD

            # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))

            query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
                                   self.transpose_batch_sequence, self.rotary_pos_emb_group_method)
            key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence,
                                 self.rotary_pos_emb_group_method)
            qkv_layout = QKVLayout.BSHD_BSHD_BSHD

        if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if decode:
            # The autoregressive cache only supports separate Q/K/V tensors.
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            is_initialized = self.has_variable('cache', 'cached_key')

            # On the first (initializing) call the cache is created with the full key/value
            # shapes; later calls write one step at a time into it.
            cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
                                         value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                # The cached key's head axis holds num_gqa_groups heads, whereas the query
                # carries num_attention_heads heads. Build the expected query shape from the
                # latter so that GQA decoding does not trip the sanity check below.
                if self.transpose_batch_sequence:
                    length, batch, _, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, _, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, self.num_attention_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        f"expected query shape {expected_shape} instead got {query.shape}.")

                # Write the single new key/value step into position cur_index of the cache.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                # Mask out cache positions that have not been written yet.
                mask = combine_masks(
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))

                if bias is not None:
                    # Select the bias row for the current decoding position.
                    dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim,
                                                       in_axes=(None, 0, None, None))
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                       jnp.reshape(cur_index, (-1)), 1, -2)

        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

        leading_axes = (BATCH_AXES, SEQLEN_AXES)
        if self.transpose_batch_sequence:
            leading_axes = (SEQLEN_AXES, BATCH_AXES)

        if qkv_layout == QKVLayout.BS3HD:
            qkv_proj = qkv_proj.reshape(*qkv_proj.shape[:2], 3, self.num_attention_heads,
                                        self.head_dim)
            qkv_sharding_constraint = (*leading_axes, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint)
            dpa_args = [qkv_proj, None, None]
        elif qkv_layout == QKVLayout.BSHD_BS2HD:
            query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim)
            kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim)
            q_sharding_constraint = (*leading_axes, HEAD_AXES, HIDDEN_AXES)
            kv_sharding_constraint = (*leading_axes, JOINED_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
            kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)
            dpa_args = [query, kv_proj, None]
        else:
            assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
            query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = (*leading_axes, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
            dpa_args = [query, key, value]

        x = DotProductAttention(head_dim=self.head_dim,
                                num_attention_heads=self.num_attention_heads,
                                num_gqa_groups=self.num_gqa_groups,
                                attn_mask_type=self.attn_mask_type,
                                attn_bias_type=self.attn_bias_type,
                                attention_dropout=self.attention_dropout,
                                dtype=self.dtype,
                                dropout_rng_name=self.dropout_rng_name,
                                float32_logits=self.float32_logits,
                                qkv_layout=qkv_layout.name,
                                scale_factor=scale_factor,
                                transpose_batch_sequence=self.transpose_batch_sequence)(
                                    *dpa_args, mask, bias, deterministic=deterministic)

        # Merge the (heads, head_dim) axes back into one hidden axis for the output projection.
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = (*leading_axes, HIDDEN_TP_AXES)
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)

        out = DenseGeneral(features=inputs_q.shape[-1],
                           transpose_batch_sequence=self.transpose_batch_sequence,
                           axis=-1,
                           kernel_init=self.kernel_init,
                           kernel_axes=(W_TP_AXES, W_FSDP_AXES),
                           use_bias=self.use_bias,
                           bias_init=self.bias_init,
                           bias_axes=(W_NO_SHARD_AXES,),
                           dtype=self.dtype,
                           name='out')(x)
        out = checkpoint_name(out, 'out_proj')

        return out, ln_out
1113
1114


1115
class RelativePositionBiases(nn.Module):    # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Parameters
    ----------
    num_buckets: int
        The number of buckets to bucket distances between key and query positions into.
    max_distance: int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads: int
        Number of attention heads in the transformer layer.
    embedding_init: Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type of the computed attention bias. The embedding table itself is
        always allocated in float32 and cast to this dtype before use.
    """
    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets')
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen: int
            The sequence length of query.
        k_seqlen: int
            The sequence length of key.
        bidirectional: bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # relative_position[i, j] = j - i, shape (q_seqlen, k_seqlen). This is computed with
        # NumPy, so q_seqlen and k_seqlen must be concrete (non-traced) values.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in half: one half per direction, with the offset selecting
            # the half used for keys after the query.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Unidirectional: future positions collapse onto distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # Small distances get one bucket each; larger ones are spaced logarithmically up to
        # max_distance and clipped into the last bucket.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) /
            np.log(self.max_distance / rpb_max_exact) *
            (rpb_num_buckets - rpb_max_exact)).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = nn_partitioning.param_with_axes(
            'rel_embedding',
            self.embedding_init, (self.num_attention_heads, self.num_buckets),
            jnp.float32,
            axes=self.embedding_axes)

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # Gather the per-bucket bias with a one-hot contraction:
        # (heads, buckets) x (buckets, q_seqlen, k_seqlen) -> (heads, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot,
                                 (((1,), (0,)), ((), ())))
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer.

    Values
    ----------
    ENCODER:
        Encoder type of TransformerLayer: a self-attention block followed by an MLP.
    DECODER:
        Decoder type of TransformerLayer: a self-attention block, a cross-attention
        block over the encoder output, and an MLP.
    """
    ENCODER = "encoder"
    DECODER = "decoder"


class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups: int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma: bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op applied to the output of each
        attention block and of the FC2 layer, before the residual add.
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden states.
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The key in the RNGs given via flax.linen.Module.apply that is used for
        generating every Dropout mask in this layer (attention, hidden,
        intermediate and drop-path dropout).
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        Whether to compute attention logits in float32 for the unfused attention backend.
        For fused attention backend, the accumulation is always float32 without the perf overhead.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: str, default = 'causal'
        Type of the attention mask passed into softmax operation in the self attention.
        Available options: {'no_mask', 'padding', 'causal', 'causal_padding'}
        Introduced in v0.10.0.
    self_attn_bias_type: Optional[str], default = None
        Type of the attention bias passed into the self attention.
        Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
        When default is present, the type is automatically decided by the MHA's bias parameter.
        Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    rotary_pos_emb_group_method: str, default = 'consecutive'
        Indicate the method used to couple the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
        where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism for operations other than dot.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence: bool, default = False
        Indicate whether the input tensors have their batch and sequence-length
        axes swapped. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{QK^T}{\sqrt{head\_dim}}`,
        else :math:`QK^T`
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
    """

    # NOTE: field order is part of the dataclass constructor signature; keep it stable.
    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = 'dropout'
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = 'causal'
    self_attn_bias_type: Optional[str] = None
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        # Resolve `None` defaults that cannot be expressed as static dataclass defaults.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
                                                                    'truncated_normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs: Array,
                 encoded: Array = None,
                 attention_mask: Array = None,
                 encoder_decoder_mask: Array = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: Optional[int] = None):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor of rank 3, laid out as described by
            :attr:`transpose_batch_sequence`.
        encoded: jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
        encoder_decoder_mask: jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length: Optional[int], default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER`,
            :attr:`enable_relative_embedding=True` and :attr:`decode=True`.
            When not given, the sequence length of `inputs` is used.

        Returns
        -------
        outputs: jax.numpy.ndarray
            Output tensors.
        """
        assert self.layer_type in TransformerLayerType, \
                "layer_type should be one of TransformerLayerType" \
                f", but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, \
                "hidden_size should be multiples of num_attention_heads" \
                f", but got {self.hidden_size=} and {self.num_attention_heads=}."

        assert self.layer_type == TransformerLayerType.DECODER or \
              (self.layer_type == TransformerLayerType.ENCODER and decode is False), \
               "decode should be False when layer_type == TransformerLayerType.ENCODER."

        # Validate the rank before `inputs.shape[sequence_dim]` is read below.
        assert inputs.ndim == 3, f"inputs should be a 3D tensor, but got {inputs.ndim=}."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
            # Logical sharding axes for the (batch, seqlen) dims in the configured layout.
            # The sequence axis is TP-sharded when sequence parallelism is active,
            # unless the caller overrides it via `is_shared_seq`.
            axes = [None, None]

            is_shared_seq = self.enable_sequence_parallel if is_shared_seq is None \
                            else is_shared_seq

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(num_buckets=32,
                                                 max_distance=128,
                                                 num_attention_heads=self.num_attention_heads,
                                                 dtype=self.dtype,
                                                 embedding_init=nn.initializers.variance_scaling(
                                                     1.0, 'fan_avg', 'uniform'),
                                                 name='relpos_bias')
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                # Encoder self-attention is bidirectional.
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                # Decoder self-attention is unidirectional; while decoding, size the bias
                # for the full decode length when one is given.
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        # Keep the submodule names identical to T5X, since names affect the
        # RNG keys during init and apply. This may not be needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = 'attention'
        else:
            mha_name = 'self_attention'

        inputs = with_sharding_constraint_by_logical_axes(
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        residual = inputs
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            input_layernorm=not self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name)(inputs,
                           inputs,
                           attention_mask,
                           attn_bias,
                           deterministic=deterministic,
                           decode=decode)

        def hidden_dropout(x, deterministic):
            # Dropout on a block output; `hidden_dropout_dims` must be valid axes of `x`.
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(rate=self.hidden_dropout,
                              broadcast_dims=self.hidden_dropout_dims,
                              rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

        x = with_sharding_constraint_by_logical_axes(
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape,
                           rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert encoded is not None, \
                "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = with_sharding_constraint_by_logical_axes(
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

            residual = x
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                attention_dropout=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                return_layernorm_output=self.apply_residual_connection_post_layernorm,
                input_layernorm=True,    # Must do LayerNorm before MHA.
                attn_mask_type='padding',
                attn_bias_type='no_bias',
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv_params=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name='encoder_decoder_attention')(x,
                                                  encoded,
                                                  encoder_decoder_mask,
                                                  deterministic=deterministic)

            y = with_sharding_constraint_by_logical_axes(
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
            residual = with_sharding_constraint_by_logical_axes(
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

            y = hidden_dropout(y, deterministic)

            if self.apply_residual_connection_post_layernorm:
                assert ln_out is not None
                residual = ln_out

            mlp_input = y + residual

        mlp_input = with_sharding_constraint_by_logical_axes(
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            name='mlp',
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = with_sharding_constraint_by_logical_axes(
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            # Use the configured RNG collection, matching the attention-branch drop_path.
            z = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape,
                           rng_collection=self.dropout_rng_name)(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            z = with_sharding_constraint_by_logical_axes(
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
            z = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          epsilon=self.layernorm_epsilon,
                          scale_axes=(W_NO_SHARD_AXES,),
                          bias_axes=(W_NO_SHARD_AXES,),
                          transpose_batch_sequence=self.transpose_batch_sequence,
                          dtype=self.dtype,
                          name="output_layernorm")(z)

        return z