transformer.py 55 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
import os
11
from typing import Any, Callable, Optional, Sequence, Tuple, Union
12
import warnings
13

14
import jax
15
16
17
18
19
20
21
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
22
from jax.ad_checkpoint import checkpoint_name
23
24
25

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
26
from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout
27
28
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
29
from ..softmax import SoftmaxType
30
31
from ..sharding import global_mesh_resource, num_of_devices
from ..sharding import with_sharding_constraint
32
33
34
35
36
37
38
39
40
41

# Type aliases used in the annotations of this module.

# PRNG keys are annotated loosely; no concrete key type is enforced.
PRNGKey = Any
# A tensor shape: a tuple of ints of arbitrary length.
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Accepted forms of a precision setting: nothing, a string, a `lax.Precision`,
# or a pair of either.
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
# A parameter initializer: callable taking (key, shape, dtype) and returning an array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Logical axis rules: a sequence of (axis_name, mesh_axis_name) pairs, where the
# mesh axis name may be None.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]

42
43
44
45
46
47
48
49
50
51
52
# Logical axis names used to annotate activations and weights in this module.
# `extend_logical_axis_rules` maps each of them to a mesh axis (or to None).

# Activation axes: batch -> data-parallel and/or FSDP resource; head and hidden_tp ->
# tensor-parallel resource; seqlen, hidden and joined -> None.
BATCH_AXES = 'nvte_batch'
SEQLEN_AXES = 'nvte_seqlen'
HEAD_AXES = 'nvte_head'
HIDDEN_AXES = 'nvte_hidden'
HIDDEN_TP_AXES = 'nvte_hidden_tp'
JOINED_AXES = 'nvte_joined'
# Weight axes: w_fsdp -> FSDP resource; w_tp -> tensor-parallel resource;
# w_no_shard and w_joined -> None.
W_NO_SHARD_AXES = 'nvte_w_no_shard'
W_FSDP_AXES = 'nvte_w_fsdp'
W_TP_AXES = 'nvte_w_tp'
W_JOINED_AXES = 'nvte_w_joined'

53
54
55
56
57
58
59
60
61
62

def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to Figure 3 in `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules : Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend.

    Returns
    -------
    extended_rules : Sequence[Tuple[str, Union[str, None]]]
        the extended Flax logical axis rules: all given rules, in order, followed by
        every TE rule whose axis name is not already present in `rules`.

    Raises
    ------
    AssertionError
        If an item of `rules` is not a ``(str, str | None)`` pair, or if `rules` already
        maps one of TE's axis names to something other than TE's own mapping.
    """
    # Group the user-provided mesh axes by logical axis name so conflicts with TE's
    # rules can be detected below.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, \
            "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key, val = item
        assert isinstance(key, str), \
            f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (val is None), \
            f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    gsr = global_mesh_resource()

    # The batch axis is sharded over the data-parallel resource and, when it is a
    # distinct resource, over the FSDP resource as well.
    batch_dim_rule = []
    if gsr.dp_resource is not None:
        batch_dim_rule.append(gsr.dp_resource)
    if gsr.fsdp_resource is not None and gsr.dp_resource != gsr.fsdp_resource:
        batch_dim_rule.append(gsr.fsdp_resource)

    # Collapse to None / a single name / a tuple of names.
    if not batch_dim_rule:
        batch_dim_rule = None
    elif len(batch_dim_rule) == 1:
        batch_dim_rule = batch_dim_rule[0]
    else:
        batch_dim_rule = tuple(batch_dim_rule)

    te_logical_axis_rules = (
        (BATCH_AXES, batch_dim_rule),
        (SEQLEN_AXES, None),
        (HEAD_AXES, gsr.tp_resource),
        (HIDDEN_AXES, None),
        (HIDDEN_TP_AXES, gsr.tp_resource),
        (JOINED_AXES, None),
        (W_NO_SHARD_AXES, None),
        (W_FSDP_AXES, gsr.fsdp_resource),
        (W_TP_AXES, gsr.tp_resource),
        (W_JOINED_AXES, None),
    )

    extended_rules = [*rules]
    for item in te_logical_axis_rules:
        key, val = item
        if key in rules_map:
            # The caller already defined this axis; it must agree exactly with TE.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \
                f"The rule diverged between TE and given rule. " \
                f"Axis:{key} map to {rules_map[key]} in the given" \
                f" rules, but {val} in TE's rules."
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


146
147
148
149
150
151
152
153
154
155
156
157
def _with_sharding_constraint(x: Array, logical_axis_names: Shape):
    """
    Apply a sharding constraint to `x` described by logical axis names.

    Each name in `logical_axis_names` (one per dimension of `x`) is looked up in the
    rules returned by `extend_logical_axis_rules` for an empty rule set; the resulting
    mesh axis names form the `PartitionSpec` handed to `with_sharding_constraint`.
    A name that is missing from those rules raises KeyError.
    """
    assert len(x.shape) == len(logical_axis_names)
    # The rules are (logical name, mesh axis) pairs, so they convert directly to a dict.
    logical_to_mesh = dict(extend_logical_axis_rules(tuple()))
    pspec = jax.sharding.PartitionSpec(*(logical_to_mesh[name] for name in logical_axis_names))
    return with_sharding_constraint(x, pspec)


158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def _merge_mask(func, *masks: Optional[Array]):
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(map(lambda x: x.ndim == masks[0].ndim,
                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = func(mask, other_mask)
    return mask


def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
    """
    Combine attention masks.

    The non-None masks are merged with `jnp.logical_and` and the result is cast to
    `dtype`. All non-None masks must have the same rank.

    Returns None when no mask is given or every mask is None, instead of failing on
    the missing result.
    """
    merged = _merge_mask(jnp.logical_and, *masks)
    if merged is None:
        # `_merge_mask` yields None when there is nothing to combine; there is
        # nothing to cast in that case.
        return None
    return merged.astype(dtype)


def combine_biases(*masks: Optional[Array]):
    """
    Combine attention biases.

    The non-None arguments are summed with `+` from left to right via `_merge_mask`,
    which skips None entries and returns None when no bias is left.
    """

    def _add(lhs, rhs):
        return lhs + rhs

    return _merge_mask(_add, *masks)


def core_attention(query: Array,
                   key: Array,
                   value: Array,
                   scale_factor: float,
                   transpose_batch_sequence: bool,
                   softmax_type: SoftmaxType = SoftmaxType.SCALED,
                   mask: Optional[Array] = None,
                   bias: Optional[Array] = None,
                   dropout_rng: Optional[PRNGKey] = None,
                   dropout_rate: float = 0.,
                   deterministic: bool = False,
                   dtype: DType = jnp.float32,
                   float32_logits: bool = False):
    """
    Unfused dot-product attention: softmax(Q.K * scale [+ bias]) . V, with optional
    grouped-query attention (GQA) and attention dropout.

    Parameters
    ----------
    query, key, value : jax.numpy.ndarray
        Rank-4 tensors laid out as (batch, seqlen, heads, head_dim), or as
        (seqlen, batch, heads, head_dim) when `transpose_batch_sequence` is True.
        `key` and `value` must share sequence length and head count; `query` and `key`
        must share head_dim. The number of query heads may be a multiple of the number
        of key/value heads, in which case the GQA code path is used.
    scale_factor : float
        Scale applied to the logits. It is multiplied into the logits directly when
        `bias` is given, otherwise it is passed to the Softmax module.
    transpose_batch_sequence : bool
        Selects which of the two layouts above the inputs use.
    softmax_type : SoftmaxType
        Forwarded to the Softmax module.
    mask, bias : jax.numpy.ndarray, optional
        Forwarded to the Softmax module together with the logits.
    dropout_rng : PRNGKey, optional
        Key used to sample the dropout keep-mask; only read when dropout is active.
    dropout_rate : float
        Dropout probability applied to the attention weights.
    deterministic : bool
        Disables dropout when True.
    dtype : jax.numpy.dtype
        Data type the softmax output is cast to.
    float32_logits : bool
        When True, `query` and `key` are cast to float32 before the logits are computed.

    Returns
    -------
    outputs : jax.numpy.ndarray
        Attention output in the same layout as `query`.
    """
    assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
    batch_dim = 1 if transpose_batch_sequence else 0
    assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
        'q, k, v batch dims must match.')
    sequence_dim = 0 if transpose_batch_sequence else 1
    assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
    assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
    assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'

    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    h_q, h_kv = query.shape[-2], key.shape[-2]
    # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
    # Therefore, we have to maintain two code paths.
    is_gqa = (h_q != h_kv)

    if is_gqa:
        assert (h_q % h_kv == 0) and (h_q >= h_kv)
        group_size = h_q // h_kv
        # Split the query heads into (kv_head, group) so each kv head serves a whole group.
        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

    if transpose_batch_sequence:
        if is_gqa:
            attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
        else:
            attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
    else:
        if is_gqa:
            attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
        else:
            attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

    attn_weights = checkpoint_name(attn_weights, 'logits')

    if is_gqa:
        # Fold (kv_head, group) back into one head axis for the sharding constraint and
        # the softmax; the grouped shape is restored afterwards.
        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
        attn_weights_without_groups_shape = (b, h * g, q, k)
        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

    attn_weights = _with_sharding_constraint(attn_weights,
                                             (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))

    # When a bias is present, the computation is performed as Softmax(attn_weights * scale + bias).
    # In this case, the scale can not fused into the Softmax module.
    if bias is not None:
        attn_weights = attn_weights * scale_factor
        fused_scale_factor = 1.
    else:
        # If no bias, the scale can be fused into Softmax module
        fused_scale_factor = scale_factor

    attn_weights = Softmax(softmax_type=softmax_type,
                           scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)

    if is_gqa:
        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        dropout_shape = list(attn_weights.shape)
        # TODO(rewang): add attention dropout broadcast dimension arguments for users
        keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        # Kept entries are rescaled by 1 / keep_prob; dropped entries become zero.
        multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if transpose_batch_sequence:
        if is_gqa:
            return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
        return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)

    if is_gqa:
        return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
    return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
271
272
273
274
275


# `lax.dynamic_slice_in_dim` vectorized with `vmap` over its second positional argument
# only (in_axes=(None, 0, None, None)); the other three arguments are shared by all
# mapped calls.
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))


276
class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
277
278
279
280
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

281
    .. note::
282

283
284
        Argument :attr:`mask` will be ignored when
        :attr:`attn_mask_type` is set to `"causal"`.
285

286
287
288
    Parameters
    ----------
    head_dim : int
289
        The hidden dimension of each attention head.
290
    num_heads : int
291
        The number of attention heads
zlsh80826's avatar
zlsh80826 committed
292
293
294
295
296
297
298
299
    num_gqa_groups : int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the querys.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
300
    dropout_rate : float, default = 0.0
301
        Dropout probability for the dropout op during multi-head attention.
302
    dropout_rng_name: str, default = 'dropout'
303
304
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
305
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
306
        Indicate the type of layer normalization.
307
    layernorm_epsilon: float, default = 1e-6
308
        A value added to the denominator of layer normalization for numerical stability.
309
310
311
312
313
314
315
316
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
317
318
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
319
320
        Used for initializing the QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
321
    use_bias: bool, default = False
322
323
        Indicate whether or not to enable bias shifting for QKVO projections.
        If set to False, the layer will not learn additive biases.
324
    bias_init: Initializer, default = flax.linen.initializers.zeros
325
326
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
327
    apply_residual_connection_post_layernorm : bool, default = False
328
        Indicate if apply residual connection with the output of layer normalization.
329
    output_layernorm : bool, default = False
330
        Indicate if apply a layer normalization at the end of MHA.
331
332
    attn_mask_type: {'causal', 'padding'}, default = 'causal'
        Type of attention mask passed into softmax operation.
333
        Introduced in v0.10.0.
334
335
336
337

    Optimization parameters
    -----------------------
    dtype :jax.numpy.dtype, default  = jax.numpy.float32
338
        The data type used to allocate the initial parameters.
339
    fuse_qkv: bool, default = True
340
        If set to True, this module exposes a single fused
341
342
343
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence : bool, default = True
344
        Indicate whether the input tensors were switched axis of batch
345
346
347
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
348
349
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
350
351
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
352
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
353
    float32_logits : bool, default = False
354
        Whether to compute attention logits in float32.
355
356
357
358
    """

    head_dim: int
    num_heads: int
zlsh80826's avatar
zlsh80826 committed
359
    num_gqa_groups: int | None = None
360
361
362
363
    dropout_rate: float = 0.
    dropout_rng_name: str = 'dropout'
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
364
    zero_centered_gamma: bool = False
365
366
367
368
369
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
370
    attn_mask_type: str = 'causal'
371
372
373
374
375
376
377
378
379
380
    dtype: DType = jnp.float32
    fuse_qkv: bool = True
    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False    # computes logits in float32 for stability.

    def __post_init__(self):
        """
        Resolve fields whose default is `None`, then defer to the parent `__post_init__`.

        `kernel_init` falls back to ``variance_scaling(1.0, 'fan_in', 'normal')`` and
        `num_gqa_groups` falls back to `num_heads`.
        """
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs_q: Array,
                 inputs_kv: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> Array:
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q : jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv : jax.numpy.ndarray
            Input tensor for key/value projection.
        mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
        bias : jax.numpy.ndarray, default = None
            A tensor used to shift self-attention softmax input.
        *
        decode : bool,default = False
            Indicate whether to prepare and use an autoregressive cache.
        deterministic : bool,default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
419
420

        def query_init(*args):
421
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

454
455
456
457
458
459
460
461
        # TODO(rewang): make it configurable for pre_scale_bias
        attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS

        def canonicalize_attn_mask_type(attn_mask_type):
            """
            Convert the string to AttnMaskType
            """
            if attn_mask_type == 'causal':
462
                return AttnMaskType.PADDING_CAUSAL_MASK
463
464
465
466
467
            if attn_mask_type == 'padding':
                return AttnMaskType.PADDING_MASK
            raise ValueError(f"Unsupported {attn_mask_type=}, "
                             "supported attn_mask_type = {'causal', 'padding'}")

468
        is_self_attn = (inputs_q is inputs_kv)
zlsh80826's avatar
zlsh80826 committed
469
470
        is_gqa = (self.num_heads != self.num_gqa_groups)
        is_qkvpack = (is_self_attn and not is_gqa)
471
        qkv_layout = QKVLayout.BS3HD if is_self_attn else QKVLayout.BSHD_BS2HD
472
        attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
473

474
475
        q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
        kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
476
        enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
477

478
        has_fused_attn_kernel = is_fused_attn_kernel_available(self.dtype, self.dtype, qkv_layout,
479
                                                               attn_bias_type, attn_mask_type,
zlsh80826's avatar
zlsh80826 committed
480
481
482
                                                               self.dropout_rate, self.num_heads,
                                                               self.num_gqa_groups, q_seqlen,
                                                               kv_seqlen, self.head_dim)
483

484
        use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
485
            has_fused_attn_kernel and \
486
            enable_fused_attn
487

488
        if enable_fused_attn and not use_fused_attn:
489
490
491
492
493
494
495
496
            reason = ""
            if decode:
                reason += f"decode=False is required but got {decode}, "
            if self.transpose_batch_sequence:
                reason += f"transpose_batch_sequence=False is required " \
                          f"but got {self.transpose_batch_sequence}, "
            if not self.fuse_qkv:
                reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
497
498
            if not has_fused_attn_kernel:
                reason += "no fused attention kernel is available, "
499
500

            warnings.warn(
501
502
503
                f"Fused attention is not enabled. Because " \
                f"{reason}fall back to unfused attention.")

504
505
        residual = inputs_q
        if self.fuse_qkv:
zlsh80826's avatar
zlsh80826 committed
506
            if is_qkvpack:
507
508
509
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=not self.output_layernorm,
                    layernorm_type=self.layernorm_type,
510
                    zero_centered_gamma=self.zero_centered_gamma,
511
512
513
514
515
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_heads * self.head_dim),
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.apply_residual_connection_post_layernorm,
516
517
518
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
519
520
521
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
522
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
523
524
                    name='qkv',
                    dtype=self.dtype)(inputs_q)
525
                qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
526
527
                if not use_fused_attn:
                    query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
528
529
530
531
            else:
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=not self.output_layernorm,
                    layernorm_type=self.layernorm_type,
532
                    zero_centered_gamma=self.zero_centered_gamma,
533
534
535
536
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_heads * self.head_dim,
                    transpose_batch_sequence=self.transpose_batch_sequence,
zlsh80826's avatar
zlsh80826 committed
537
538
                    return_layernorm_output=(self.apply_residual_connection_post_layernorm
                                             or is_self_attn),
539
540
541
                    scale_axes=(W_NO_SHARD_AXES,),
                    ln_bias_axes=(W_NO_SHARD_AXES,),
                    kernel_axes=(W_FSDP_AXES, W_TP_AXES),
542
543
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
544
                    bias_axes=(W_TP_AXES,),
545
546
547
                    dtype=self.dtype,
                    kernel_init=query_init,
                    name='query')(inputs_q)
zlsh80826's avatar
zlsh80826 committed
548
549
550
551
552

                if is_self_attn:
                    assert ln_out is not None
                    inputs_kv = ln_out

553
                kv_proj = DenseGeneral(axis=-1,
zlsh80826's avatar
zlsh80826 committed
554
                                       features=(2, self.num_gqa_groups * self.head_dim),
555
                                       transpose_batch_sequence=self.transpose_batch_sequence,
556
                                       kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
557
558
559
                                       kernel_init=kv_init,
                                       use_bias=self.use_bias,
                                       bias_init=self.bias_init,
560
                                       bias_axes=(W_JOINED_AXES, W_TP_AXES),
561
562
                                       name='kv',
                                       dtype=self.dtype)(inputs_kv)
563
                kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
564
565
                if not use_fused_attn:
                    key, value = jnp.split(kv_proj, [1], axis=-2)
566
567
568
569
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
zlsh80826's avatar
zlsh80826 committed
570
                features=self.num_gqa_groups * self.head_dim,
571
                transpose_batch_sequence=self.transpose_batch_sequence,
572
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
573
574
                use_bias=self.use_bias,
                bias_init=self.bias_init,
575
                bias_axes=(W_TP_AXES,),
576
577
578
579
                dtype=self.dtype)
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=not self.output_layernorm,
                layernorm_type=self.layernorm_type,
580
                zero_centered_gamma=self.zero_centered_gamma,
581
582
583
584
585
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_heads * self.head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
586
587
588
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes=(W_FSDP_AXES, W_TP_AXES),
589
590
                use_bias=self.use_bias,
                bias_init=self.bias_init,
591
                bias_axes=(W_TP_AXES,),
592
593
594
595
                dtype=self.dtype,
                kernel_init=query_init,
                name='query')(inputs_q)

596
            if is_self_attn:
597
598
599
600
601
602
603
604
605
606
                assert ln_out is not None
                inputs_kv = ln_out

            key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

607
        if not use_fused_attn:
608
609
610
            query = checkpoint_name(query, 'query_proj')
            key = checkpoint_name(key, 'key_proj')
            value = checkpoint_name(value, 'value_proj')
zlsh80826's avatar
zlsh80826 committed
611
612
613
            query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
614
            qkv_sharding_constraint = \
615
                (SEQLEN_AXES, BATCH_AXES, HEAD_AXES, HIDDEN_AXES) \
616
                if self.transpose_batch_sequence \
617
618
619
620
                else (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
            query = _with_sharding_constraint(query, qkv_sharding_constraint)
            key = _with_sharding_constraint(key, qkv_sharding_constraint)
            value = _with_sharding_constraint(value, qkv_sharding_constraint)
621
622
623
624

        if decode:
            is_initialized = self.has_variable('cache', 'cached_key')

Ming-Xu Huang's avatar
Ming-Xu Huang committed
625
626
            cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
627
628
629
630
                                         value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
Ming-Xu Huang's avatar
Ming-Xu Huang committed
631
632
633
634
635
636
637
638
                if self.transpose_batch_sequence:
                    length, batch, num_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, num_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)
639
640
641
642
643
644
645
646
647

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        f"expected query shape {expected_shape} instead got {query.shape}.")

                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
Ming-Xu Huang's avatar
Ming-Xu Huang committed
648
649
650
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
651
652
653
654
655
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                mask = combine_masks(
Ming-Xu Huang's avatar
Ming-Xu Huang committed
656
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))
657
658
659
660
661

                if bias is not None:
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                       jnp.reshape(cur_index, (-1)), 1, -2)

662
663
        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

664
665
666
667
        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng(self.dropout_rng_name)

668
669
670
        if use_fused_attn:
            assert mask is not None and mask.ndim == 4    # (b, 1, s_q, s_kv)
            assert not self.transpose_batch_sequence
671

672
673
            seed = None
            if dropout_rng is not None:
674
                seed = jax.random.split(dropout_rng, num_of_devices())
675
676
677
                # ensure the old key never used
                del dropout_rng

zlsh80826's avatar
zlsh80826 committed
678
            if is_qkvpack:
679
                qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
680
681
682
                qkv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
                                           HIDDEN_AXES)
                qkv_proj = _with_sharding_constraint(qkv_proj, qkv_sharding_constraint)
683
684
685
                x = self_fused_attn(qkv_proj,
                                    bias,
                                    mask,
686
                                    seed,
687
                                    attn_bias_type=attn_bias_type,
688
                                    attn_mask_type=attn_mask_type,
689
690
                                    scaling_factor=scale_factor,
                                    dropout_probability=self.dropout_rate,
691
                                    is_training=not deterministic)
692
693
694
            else:
                assert bias is None
                query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
zlsh80826's avatar
zlsh80826 committed
695
                kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_gqa_groups, self.head_dim))
696
697
698
699
700
                q_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, HEAD_AXES, HIDDEN_AXES)
                kv_sharding_constraint = (BATCH_AXES, SEQLEN_AXES, JOINED_AXES, HEAD_AXES,
                                          HIDDEN_AXES)
                query = _with_sharding_constraint(query, q_sharding_constraint)
                kv_proj = _with_sharding_constraint(kv_proj, kv_sharding_constraint)
701
702
703

                x = cross_fused_attn(query,
                                     kv_proj,
704
                                     bias,
705
                                     mask,
706
                                     seed,
707
                                     attn_bias_type=attn_bias_type,
708
                                     attn_mask_type=attn_mask_type,
709
710
                                     scaling_factor=scale_factor,
                                     dropout_probability=self.dropout_rate,
711
                                     is_training=not deterministic)
712
        else:
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727

            def convert_to_softmax_type(attn_mask_type, mask):
                """
                Convert the string to SoftmaxType
                """
                if attn_mask_type == 'causal':
                    return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
                if attn_mask_type == 'padding':
                    if mask is not None:
                        return SoftmaxType.SCALED_MASKED
                    return SoftmaxType.SCALED
                raise ValueError(f"Unsupported {attn_mask_type=}, "
                                 "supported attn_mask_type = {'causal', 'padding'}")

            softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
728
729
730
731
732
733
734
735
736
737
738
739
740
741

            x = core_attention(query,
                               key,
                               value,
                               scale_factor=scale_factor,
                               transpose_batch_sequence=self.transpose_batch_sequence,
                               softmax_type=softmax_type,
                               mask=mask,
                               bias=bias,
                               dropout_rng=dropout_rng,
                               dropout_rate=self.dropout_rate,
                               deterministic=deterministic,
                               dtype=self.dtype,
                               float32_logits=self.float32_logits)
742

743
744
            x = checkpoint_name(x, 'context')

745
746
747
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = \
748
            (SEQLEN_AXES, BATCH_AXES, HIDDEN_TP_AXES) \
749
            if self.transpose_batch_sequence \
750
751
            else (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
        x = _with_sharding_constraint(x, attn_context_sharding_constraint)
752
753
754
755
756

        out = DenseGeneral(features=inputs_q.shape[-1],
                           transpose_batch_sequence=self.transpose_batch_sequence,
                           axis=-1,
                           kernel_init=self.kernel_init,
757
                           kernel_axes=(W_TP_AXES, W_FSDP_AXES),
758
759
                           use_bias=self.use_bias,
                           bias_init=self.bias_init,
760
                           bias_axes=(W_NO_SHARD_AXES,),
761
762
                           dtype=self.dtype,
                           name='out')(x)
763
        out = checkpoint_name(out, 'out_proj')
764
765
766
        return out, residual


767
class RelativePositionBiases(nn.Module):    # pylint: disable=too-few-public-methods
    """
    T5-style relative positional embeddings to the attention logits.

    Parameters
    ----------
    num_buckets : int
        The number of buckets to bucket distances between key and query positions into.
    max_distance : int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads : int
        Number of attention heads in the transformer layer.
    embedding_init : Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """
    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets')
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen : int
            The sequence length of query.
        k_seqlen : int
            The sequence length of key.
        bidirectional : bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # Column vector of query positions and row vector of key positions; their
        # difference broadcasts to a (q_seqlen, k_seqlen) matrix of offsets.
        context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Split the buckets in two halves: keys located after the query
            # (negative_rp < 0) are shifted into the upper half, then only the
            # absolute distance is bucketed.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            # Unidirectional: keys located after the query are clamped to distance 0.
            negative_rp = np.maximum(negative_rp, 0)

        # The first half of the remaining buckets holds one exact distance each;
        # larger distances are spread logarithmically up to max_distance and
        # anything beyond is clipped into the last bucket.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) /
            np.log(self.max_distance / rpb_max_exact) *
            (rpb_num_buckets - rpb_max_exact)).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias
        # The embedding table is declared in float32 and cast to self.dtype below.
        relative_attention_bias = nn_partitioning.param_with_axes(
            'rel_embedding',
            self.embedding_init, (self.num_attention_heads, self.num_buckets),
            jnp.float32,
            axes=self.embedding_axes)

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # Look the buckets up through a one-hot contraction instead of a gather:
        # (heads, buckets) x (buckets, q_seqlen, k_seqlen) -> (heads, q_seqlen, k_seqlen).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot,
                                 (((1,), (0,)), ((), ())))
        # Prepend a singleton leading axis so the bias broadcasts over the batch.
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer

    Values
    ----------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """
    ENCODER = "encoder"
    DECODER = "decoder"


874
class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    .. note::

        Argument :attr:`attention_mask` will be ignored when
        :attr:`self_attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    num_gqa_groups : int, default = `None`
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
        This only affects the keys and values, not the queries.
        GQA-1 is equivalent to Multi-Query Attention
        (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
        is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC2 layer.
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout
        (forwarded as `broadcast_dims` to `flax.linen.Dropout`).
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    intermediate_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC1 layer.
    intermediate_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden after FC1 layer.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for
        generating Dropout masks in the Multi-Head Attention.
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (default is taken from input of layer norm)
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. Default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        If set to True, attention logits are executed in jax.numpy.float32.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    self_attn_mask_type: {'causal', 'padding'}, default = 'causal'
        Type of attention mask passed into softmax operation.
        Introduced in v0.10.0.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')

    Optimization parameters
    -----------------------
    dtype :jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q}{\sqrt{head\_dim}} * K`,
        else :math:`Q*K`
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = 'dropout'
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = 'causal'
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        # Resolve the `None` placeholders to their documented defaults before the
        # base class finishes initialization.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
                                                                    'truncated_normal')
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs: Array,
                 encoded: Array = None,
                 attention_mask: Array = None,
                 encoder_decoder_mask: Array = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: Optional[int] = None):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensor.
        encoded : jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
        encoder_decoder_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length : int, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        assert self.layer_type in TransformerLayerType, \
                "layer_type should be one of TransformerLayerType" \
                f", but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, \
                "hidden_size should be multiples of num_attention_heads" \
                f", but got {self.hidden_size=} and {self.num_attention_heads=}."

        assert self.layer_type == TransformerLayerType.DECODER or \
              (self.layer_type == TransformerLayerType.ENCODER and decode is False), \
               "decode should be False when layer_type == TransformerLayerType.ENCODER."

        # Check the rank before any axis of `inputs.shape` is indexed below.
        assert inputs.ndim == 3

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(num_buckets=32,
                                                 max_distance=128,
                                                 num_attention_heads=self.num_attention_heads,
                                                 dtype=self.dtype,
                                                 embedding_init=nn.initializers.variance_scaling(
                                                     1.0, 'fan_avg', 'uniform'),
                                                 name='relpos_bias')
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                # While decoding, the bias is generated for the full decode length
                # (when provided) rather than for the current input length.
                if decode and max_decode_length:
                    bias_seqlen = max_decode_length
                else:
                    bias_seqlen = inputs.shape[sequence_dim]
                attn_bias = rel_emb(bias_seqlen, bias_seqlen, False)

        # Make name be the exactly same as T5X, since names would affect
        # RNGKey during init and apply. Maybe no need in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = 'attention'
        else:
            mha_name = 'self_attention'

        # NOTE(review): the logical axes are given in (batch, seqlen, hidden) order even
        # when transpose_batch_sequence=True -- confirm this is intended.
        inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x, residual = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            attn_mask_type=self.self_attn_mask_type,
            fuse_qkv=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name)(inputs,
                           inputs,
                           attention_mask,
                           attn_bias,
                           deterministic=deterministic,
                           decode=decode)

        def hidden_dropout(x, deterministic):
            """Apply the hidden dropout, sharing the mask along `hidden_dropout_dims`."""
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(rate=self.hidden_dropout,
                              broadcast_dims=self.hidden_dropout_dims,
                              rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape,
                           rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert encoded is not None, \
                "encoded is required when layer_type == TransformerLayerType.DECODER."

            y, residual = MultiHeadAttention(
                num_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
                dropout_rate=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                apply_residual_connection_post_layernorm=self.
                apply_residual_connection_post_layernorm,
                output_layernorm=False,    # Must do LayerNorm before MHA.
                attn_mask_type='padding',
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name='encoder_decoder_attention')(x,
                                                  encoded,
                                                  encoder_decoder_mask,
                                                  deterministic=deterministic)
            y = hidden_dropout(y, deterministic)
            mlp_input = y + residual

        mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            intermediate_dropout_rng_name=self.dropout_rng_name,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            name='mlp',
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            # Draw from the configured RNG stream, matching the drop-path after
            # self-attention above, so a custom `dropout_rng_name` is honored here too.
            z = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape,
                           rng_collection=self.dropout_rng_name)(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            z = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          epsilon=self.layernorm_epsilon,
                          scale_axes=(W_NO_SHARD_AXES,),
                          bias_axes=(W_NO_SHARD_AXES,),
                          transpose_batch_sequence=self.transpose_batch_sequence,
                          dtype=self.dtype,
                          name="output_layer_norm")(z)

        return z