# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
from math import sqrt
import os
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
from jax.ad_checkpoint import checkpoint_name

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES
from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES
from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to `Figure 3 in` `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and biases.

    Parameters
    ----------
    rules : Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules : Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules.
    """
    rules_map = {}
    for item in rules:
        assert len(item) == 2, \
            "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key = item[0]
        val = item[1]
        assert isinstance(key, str), \
            f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (val is None), \
            f"The mesh_axis_name should be str or None, but got {type(val)}."
        if key in rules_map:
            rules_map[key].append(val)
        else:
            rules_map[key] = [val]

    extended_rules = [*rules]
    for item in get_sharding_map_logic_axis_to_mesh_axis().items():
        key = item[0]
        val = item[1]
        if key in rules_map:
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \
                f"The rule diverged between TE and the given rules. " \
                f"Axis:{key} maps to {rules_map[key]} in the given" \
                f" rules, but {val} in TE's rules."
        else:
            extended_rules.append(item)
    return tuple(extended_rules)
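# A minimal usage sketch for extend_logical_axis_rules (illustrative only; the base
# rule below is an assumption, not part of this module). ShardingResource must
# already be configured via fp8_autocast, as the docstring above warns.
#
#     rules = extend_logical_axis_rules((('my_activation_axis', None),))
#     with nn_partitioning.axis_rules(rules):
#         params = TransformerLayer().init(init_rng, inputs, attention_mask=mask)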


def _merge_mask(func, *masks: Optional[Array]):
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(map(lambda x: x.ndim == masks[0].ndim,
                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = func(mask, other_mask)
    return mask


def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
    """Combine attention masks."""
    func = jnp.logical_and
    return _merge_mask(func, *masks).astype(dtype)


def combine_biases(*masks: Optional[Array]):
    """Combine attention biases."""

    def func(a, b):
        return a + b

    return _merge_mask(func, *masks)
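# Illustrative sketch of the two helpers above (mask/bias shapes are assumptions):
# boolean attention masks broadcastable to (batch, 1, seqlen, seqlen) are merged
# with a logical AND, while attention biases of the same shape are simply summed.
#
#     merged_mask = combine_masks(padding_mask, causal_mask)
#     merged_bias = combine_biases(relpos_bias, other_bias)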


def core_attention(query: Array,
                   key: Array,
                   value: Array,
                   scale_factor: float,
                   transpose_batch_sequence: bool,
                   softmax_type: SoftmaxType = SoftmaxType.SCALED,
                   mask: Optional[Array] = None,
                   bias: Optional[Array] = None,
                   dropout_rng: Optional[PRNGKey] = None,
                   dropout_rate: float = 0.,
                   deterministic: bool = False,
                   dtype: DType = jnp.float32,
                   float32_logits: bool = False):
    """Core attention"""
    assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
    batch_dim = 1 if transpose_batch_sequence else 0
    assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
        'q, k, v batch dims must match.')
    sequence_dim = 0 if transpose_batch_sequence else 1
    assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
    assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
    assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'

    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    h_q, h_kv = query.shape[-2], key.shape[-2]
    # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
    # Therefore, we have to maintain two code paths.
    is_gqa = (h_q != h_kv)

    if is_gqa:
        assert (h_q % h_kv == 0) and (h_q >= h_kv)
        group_size = h_q // h_kv
        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

    if transpose_batch_sequence:
        if is_gqa:
            attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
        else:
            attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
    else:
        if is_gqa:
            attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
        else:
            attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

    attn_weights = checkpoint_name(attn_weights, 'logits')

    if is_gqa:
        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
        attn_weights_without_groups_shape = (b, h * g, q, k)
        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

    attn_weights = with_sharding_constraint_by_logical_axes(
        attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))

    # When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
    # In this case, the scale cannot be fused into the Softmax module.
    if bias is not None:
        attn_weights = attn_weights * scale_factor
        fused_scale_factor = 1.
    else:
        # If no bias, the scale can be fused into Softmax module
        fused_scale_factor = scale_factor

    attn_weights = Softmax(softmax_type=softmax_type,
                           scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)

    if is_gqa:
        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        dropout_shape = list(attn_weights.shape)
        # TODO(rewang): add attention dropout broadcast dimension arguments for users
        keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if transpose_batch_sequence:
        if is_gqa:
            return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
        return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)

    if is_gqa:
        return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
    return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
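# A minimal sketch of calling core_attention (defined above); shapes and values are
# assumptions. With transpose_batch_sequence=False, query/key/value are expected as
# (batch, seqlen, heads, head_dim) and the returned context has the shape of query.
#
#     context = core_attention(query, key, value,
#                              scale_factor=1.0 / sqrt(head_dim),
#                              transpose_batch_sequence=False,
#                              mask=merged_mask,
#                              dropout_rng=dropout_rng,
#                              dropout_rate=0.1,
#                              deterministic=False)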


def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool):
    """
    Rotary Positional Embedding
    x should be in shape of
    [Batch, Seqlen, ..., Hidden] if transpose_batch_sequence is False, or
    [Seqlen, Batch, ..., Hidden] if transpose_batch_sequence is True.
    """
    embed_dim = x.shape[-1]
    half_embed_dim = embed_dim // 2
    min_window = windows[0]
    max_window = windows[1]

    fraction = 2 * jnp.arange(0, half_embed_dim) / embed_dim
    time_scales = min_window * (max_window / min_window)**fraction
    time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))

    batch_dim = 1 if transpose_batch_sequence else 0
    seq_dim = 1 - batch_dim

    positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
    positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))

    sinusoidal_positions = positions / time_scales
    sin = jnp.sin(sinusoidal_positions)
    cos = jnp.cos(sinusoidal_positions)

    x1, x2 = jnp.split(x, 2, axis=-1)
    part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
    part_2 = (x2 * cos + x1 * sin).astype(x.dtype)

    return jnp.concatenate([part_1, part_2], axis=-1)
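# Illustrative sketch (assumed shapes): applying the rotary embedding above to a
# projected query of shape (batch, seqlen, num_heads, head_dim) with the default
# time-scale windows.
#
#     query = rotary_pos_emb(query, windows=(1, 10000), transpose_batch_sequence=False)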


dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))


class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    .. note::

        Argument :attr:`mask` will be ignored when
        :attr:`attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    head_dim : int
        The hidden dimension of each attention head.
    num_heads : int
        The number of attention heads.
    num_gqa_groups : int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    dropout_rate : float, default = 0.0
        Dropout probability for the dropout op during multi-head attention.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing the QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    use_bias: bool, default = False
        Indicate whether or not to enable bias shifting for QKVO projections.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm : bool, default = False
        Indicate whether to apply the residual connection with the output of layer normalization.
    output_layernorm : bool, default = False
        Indicate whether to apply a layer normalization at the end of MHA.
    attn_mask_type: {'causal', 'padding'}, default = 'causal'
        Type of attention mask passed into softmax operation.
        Introduced in v0.10.0.
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    fuse_qkv: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head_dim}}*K`,
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
    float32_logits : bool, default = False
        Whether to compute attention logits in float32.
    """

    head_dim: int
    num_heads: int
    num_gqa_groups: int | None = None
    dropout_rate: float = 0.
    dropout_rng_name: str = 'dropout'
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    attn_mask_type: str = 'causal'
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    dtype: DType = jnp.float32
    fuse_qkv: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False    # computes logits in float32 for stability.

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs_q: Array,
                 inputs_kv: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> Array:
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q : jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv : jax.numpy.ndarray
            Input tensor for key/value projection.
        mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
        bias : jax.numpy.ndarray, default = None
            A tensor used to shift self-attention softmax input.
        *
        decode : bool, default = False
            Indicate whether to prepare and use an autoregressive cache.
        deterministic : bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        def query_init(*args):
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

        # TODO(rewang): make it configurable for pre_scale_bias
        attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS

        def canonicalize_attn_mask_type(attn_mask_type):
            """
            Convert the string to AttnMaskType
            """
            if attn_mask_type == 'causal':
                return AttnMaskType.PADDING_CAUSAL_MASK
            if attn_mask_type == 'padding':
                return AttnMaskType.PADDING_MASK
            raise ValueError(f"Unsupported {attn_mask_type=}, "
                             "supported attn_mask_type = {'causal', 'padding'}")

        is_self_attn = (inputs_q is inputs_kv)
        is_gqa = (self.num_heads != self.num_gqa_groups)
        is_qkvpack = (is_self_attn and not is_gqa)
        qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)

        q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
        kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
                                                               attn_bias_type, attn_mask_type,
                                                               self.dropout_rate, self.num_heads,
                                                               self.num_gqa_groups, q_seqlen,
                                                               kv_seqlen, self.head_dim)

        use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
            has_fused_attn_kernel and \
            enable_fused_attn

        if enable_fused_attn and not use_fused_attn:
            reason = ""
            if decode:
                reason += f"decode=False is required but got {decode}, "
            if self.transpose_batch_sequence:
                reason += f"transpose_batch_sequence=False is required " \
                          f"but got {self.transpose_batch_sequence}, "
            if not self.fuse_qkv:
                reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
            if not has_fused_attn_kernel:
                reason += "no fused attention kernel is available, "

            warnings.warn(
                f"Fused attention is not enabled because " \
                f"{reason}falling back to unfused attention.")

        def generate_batch_seqlen_logical_axes(is_sharded_seq):
            sequence_dim = 0 if self.transpose_batch_sequence else 1
            batch_dim = 1 - sequence_dim

            axes = [None, None]

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES
            return tuple(axes)

        inputs_logical_axes_maybe_sp = (*generate_batch_seqlen_logical_axes(
            self.enable_sequence_parallel), HIDDEN_AXES)
        inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES)

        inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

        residual = inputs_q
        if self.fuse_qkv:
            if is_qkvpack:
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=not self.output_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_heads * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.apply_residual_connection_post_layernorm,
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name='qkv',
                    dtype=self.dtype)(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
                if not use_fused_attn:
                    query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=not self.output_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_heads * self.head_dim,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=(self.apply_residual_connection_post_layernorm
                                             or is_self_attn),
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_TP_AXES,),
                    dtype=self.dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name='query')(inputs_q)

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

                kv_proj = DenseGeneral(axis=-1,
                                       features=(2, self.num_gqa_groups * self.head_dim),
                                       transpose_batch_sequence=self.transpose_batch_sequence,
                                       kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                                       kernel_init=kv_init,
                                       use_bias=self.use_bias,
                                       bias_init=self.bias_init,
                                       bias_axes=(W_JOINED_AXES, W_TP_AXES),
                                       name='kv',
                                       dtype=self.dtype)(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
                if not use_fused_attn:
                    key, value = jnp.split(kv_proj, [1], axis=-2)
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
                features=self.num_gqa_groups * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                dtype=self.dtype)
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=not self.output_layernorm,
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_heads * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                dtype=self.dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
                name='query')(inputs_q)

            if is_self_attn:
                assert ln_out is not None
                inputs_kv = ln_out

            key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        if self.enable_rotary_pos_emb:
            if self.fuse_qkv and use_fused_attn:
                if is_qkvpack:
                    query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
                else:
                    key, value = jnp.split(kv_proj, [1], axis=-2)

            query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
                                   self.transpose_batch_sequence)
            key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence)

            if use_fused_attn:
                if is_qkvpack:
                    qkv_proj = jnp.concatenate([query, key, value], axis=-2)
                else:
                    kv_proj = jnp.concatenate([key, value], axis=-2)

        if not use_fused_attn:
            query = checkpoint_name(query, 'query_proj')
            key = checkpoint_name(key, 'key_proj')
            value = checkpoint_name(value, 'value_proj')
            query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
            qkv_sharding_constraint = \
                (SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
                if self.transpose_batch_sequence \
                else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
            query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint)
            key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint)
            value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)

        if decode:
            is_initialized = self.has_variable('cache', 'cached_key')

            cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
                                         value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                if self.transpose_batch_sequence:
                    length, batch, num_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, num_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        f"expected query shape {expected_shape} instead got {query.shape}.")

                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))

                if bias is not None:
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                       jnp.reshape(cur_index, (-1)), 1, -2)

        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if use_fused_attn:
            assert mask is not None and mask.ndim == 4    # (b, 1, s_q, s_kv)
            assert not self.transpose_batch_sequence

            seed = None
            if dropout_rng is not None:
                seed = jax.random.split(dropout_rng, num_of_devices())
                # ensure the old key is never used
                del dropout_rng

            if is_qkvpack:
                qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
                qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
                                           HIDDEN_AXES)
                qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj,
                                                                    qkv_sharding_constraint)

                x = self_fused_attn(qkv_proj,
                                    bias,
                                    mask,
                                    seed,
                                    attn_bias_type=attn_bias_type,
                                    attn_mask_type=attn_mask_type,
                                    scaling_factor=scale_factor,
                                    dropout_probability=self.dropout_rate,
                                    is_training=not deterministic)
            else:
                assert bias is None
                query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
                kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
                q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
                kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
                                          HIDDEN_AXES)
                query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint)
                kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint)

                x = cross_fused_attn(query,
                                     kv_proj,
                                     bias,
                                     mask,
                                     seed,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=attn_mask_type,
                                     scaling_factor=scale_factor,
                                     dropout_probability=self.dropout_rate,
                                     is_training=not deterministic)
        else:

            def convert_to_softmax_type(attn_mask_type, mask):
                """
                Convert the string to SoftmaxType
                """
                if attn_mask_type == 'causal':
                    return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
                if attn_mask_type == 'padding':
                    if mask is not None:
                        return SoftmaxType.SCALED_MASKED
                    return SoftmaxType.SCALED
                raise ValueError(f"Unsupported {attn_mask_type=}, "
                                 "supported attn_mask_type = {'causal', 'padding'}")

            softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)

            x = core_attention(query,
                               key,
                               value,
                               scale_factor=scale_factor,
                               transpose_batch_sequence=self.transpose_batch_sequence,
                               softmax_type=softmax_type,
                               mask=mask,
                               bias=bias,
                               dropout_rng=dropout_rng,
                               dropout_rate=self.dropout_rate,
                               deterministic=deterministic,
                               dtype=self.dtype,
                               float32_logits=self.float32_logits)

            x = checkpoint_name(x, 'context')

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = \
            (SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
            if self.transpose_batch_sequence \
            else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
        x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint)

        out = DenseGeneral(features=inputs_q.shape[-1],
                           transpose_batch_sequence=self.transpose_batch_sequence,
                           axis=-1,
                           kernel_init=self.kernel_init,
                           kernel_axes=(W_TP_AXES, W_FSDP_AXES),
                           use_bias=self.use_bias,
                           bias_init=self.bias_init,
                           bias_axes=(W_NO_SHARD_AXES,),
                           dtype=self.dtype,
                           name='out')(x)
        out = checkpoint_name(out, 'out_proj')
        return out, residual
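# A minimal usage sketch for MultiHeadAttention (hyper-parameters, rngs and shapes
# are assumptions). Note that __call__ returns both the attention output and the
# residual branch, which the caller is expected to add back.
#
#     mha = MultiHeadAttention(head_dim=64, num_heads=16,
#                              transpose_batch_sequence=False)
#     variables = mha.init(jax.random.PRNGKey(0), inputs, inputs, mask=mask)
#     out, residual = mha.apply(variables, inputs, inputs, mask=mask,
#                               deterministic=True)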


class RelativePositionBiases(nn.Module):    # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Parameters
    ----------
    num_buckets : int
        The number of buckets to bucket distances between key and query positions into.
    max_distance : int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads : int
        Number of attention heads in the transformer layer.
    embedding_init : Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """
    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets')
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen : int
            The sequence length of query.
        k_seqlen : int
            The sequence length of key.
        bidirectional : bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            negative_rp = np.maximum(negative_rp, 0)

        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) /
            np.log(self.max_distance / rpb_max_exact) *
            (rpb_num_buckets - rpb_max_exact)).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        relative_attention_bias = nn_partitioning.param_with_axes(
            'rel_embedding',
            self.embedding_init, (self.num_attention_heads, self.num_buckets),
            jnp.float32,
            axes=self.embedding_axes)

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot,
                                 (((1,), (0,)), ((), ())))
        return values[jnp.newaxis, ...]
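# Illustrative sketch (assumed sizes): producing a T5-style attention bias of shape
# (1, num_attention_heads, q_seqlen, k_seqlen), suitable for the `bias` argument of
# the attention block.
#
#     relpos = RelativePositionBiases(num_buckets=32, max_distance=128,
#                                     num_attention_heads=16)
#     variables = relpos.init(jax.random.PRNGKey(0), 128, 128, bidirectional=True)
#     attn_bias = relpos.apply(variables, 128, 128, bidirectional=True)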


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer

    Values
    ----------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """
    ENCODER = "encoder"
    DECODER = "decoder"
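# Illustrative sketch: selecting the layer type when constructing a TransformerLayer
# (all other arguments are left at their defaults here).
#
#     encoder_layer = TransformerLayer(layer_type=TransformerLayerType.ENCODER)
#     decoder_layer = TransformerLayer(layer_type=TransformerLayerType.DECODER)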


class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    .. note::

        Argument :attr:`attention_mask` will be ignored when
        :attr:`self_attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups : int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC2 layer.
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the Multi-Head Attention.
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. Default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        If set to True, attention logits are executed in jax.numpy.float32.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: {'causal', 'padding'}, default = 'causal'
        Type of attention mask passed into softmax operation.
        Introduced in v0.10.0.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')
    enable_rotary_pos_emb: bool, default = False
        Whether to enable rotary position embedding to projected query and key in MHA.
    rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
        Indicate the min and max time-scales of rotary position embedding,
        only used when :attr:`enable_rotary_pos_emb=True`
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head_dim}}*K`,
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = 'dropout'
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = 'causal'
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
                                                                    'truncated_normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs: Array,
                 encoded: Array = None,
                 attention_mask: Array = None,
                 encoder_decoder_mask: Array = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: bool = None):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensor.
        encoded : jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
        encoder_decoder_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length : bool, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        assert self.layer_type in TransformerLayerType, \
                "layer_type should be one of TransformerLayerType" \
                f", but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, \
                "hidden_size should be multiples of num_attention_heads" \
                f", but got {self.hidden_size=} and {self.num_attention_heads=}."

        assert self.layer_type == TransformerLayerType.DECODER or \
              (self.layer_type == TransformerLayerType.ENCODER and decode is False), \
               "decode should be False when layer_type == TransformerLayerType.ENCODER."

        head_dim = self.hidden_size // self.num_attention_heads
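        # e.g. hidden_size=1024 with num_attention_heads=16 gives head_dim=64.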

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        def generate_batch_seqlen_logical_axes(is_shared_seq=None):
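            # Logical (batch, seqlen) sharding axes in this layer's data layout:
            # the sequence axis maps to SEQLEN_TP_AXES when the sequence dimension
            # is sharded (sequence parallelism), otherwise to SEQLEN_AXES.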
            axes = [None, None]

            is_shared_seq = self.enable_sequence_parallel if is_shared_seq is None \
                            else is_shared_seq

            axes[batch_dim] = BATCH_AXES
            axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES
            return tuple(axes)

        attn_bias = None
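        # Generate the relative position bias used as the attention bias. For the
        # decoder with an autoregressive cache, the bias covers max_decode_length
        # positions; otherwise it covers the input sequence length.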
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(num_buckets=32,
                                                 max_distance=128,
                                                 num_attention_heads=self.num_attention_heads,
                                                 dtype=self.dtype,
                                                 embedding_init=nn.initializers.variance_scaling(
                                                     1.0, 'fan_avg', 'uniform'),
                                                 name='relpos_bias')
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                if decode and max_decode_length:
                    l = max_decode_length
                else:
                    l = inputs.shape[sequence_dim]
                attn_bias = rel_emb(l, l, False)

        assert inputs.ndim == 3

        # Keep the names exactly the same as in T5X, since names affect the RNG keys
        # during init and apply. This may not be needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = 'attention'
        else:
            mha_name = 'self_attention'

        inputs = with_sharding_constraint_by_logical_axes(
            inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x, residual = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            dropout_rate=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            fuse_qkv=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name)(inputs,
                           inputs,
                           attention_mask,
                           attn_bias,
                           deterministic=deterministic,
                           decode=decode)

        def hidden_dropout(x, deterministic):
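            # Dropout on the hidden states; the mask is broadcast (shared) along the
            # user-specified hidden_dropout_dims, each of which must be a valid axis of x.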
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(rate=self.hidden_dropout,
                              broadcast_dims=self.hidden_dropout_dims,
                              rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

        x = with_sharding_constraint_by_logical_axes(
            x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
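            # Stochastic depth (drop path): the dropout mask is broadcast over all
            # non-batch dimensions, so the attention branch is dropped per sample.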
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape,
                           rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert encoded is not None, \
                "encoded is required when layer_type == TransformerLayerType.DECODER."

            x = with_sharding_constraint_by_logical_axes(
                x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

            y, residual = MultiHeadAttention(
                num_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                enable_sequence_parallel=self.enable_sequence_parallel,
                dropout_rate=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                apply_residual_connection_post_layernorm=self.
                apply_residual_connection_post_layernorm,
                output_layernorm=False,    # Must do LayerNorm before MHA.
                attn_mask_type='padding',
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name='encoder_decoder_attention')(x,
                                                  encoded,
                                                  encoder_decoder_mask,
                                                  deterministic=deterministic)

            y = with_sharding_constraint_by_logical_axes(
                y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
            residual = with_sharding_constraint_by_logical_axes(
                residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

            y = hidden_dropout(y, deterministic)
            mlp_input = y + residual

        mlp_input = with_sharding_constraint_by_logical_axes(
            mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
            name='mlp',
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = with_sharding_constraint_by_logical_axes(
            z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
        residual = with_sharding_constraint_by_logical_axes(
            residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            z = with_sharding_constraint_by_logical_axes(
                z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
            z = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          epsilon=self.layernorm_epsilon,
                          scale_axes=(W_NO_SHARD_AXES,),
                          bias_axes=(W_NO_SHARD_AXES,),
                          transpose_batch_sequence=self.transpose_batch_sequence,
                          dtype=self.dtype,
                          name="output_layer_norm")(z)

        return z