# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
9
from math import sqrt
10
from typing import Any, Callable, Optional, Sequence, Tuple, Union
11
import warnings
12
13
14
15
16

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
17
from jax import dtypes
18
19
20
21
22
23
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
24
25
26
from ..fused_attn import AttnBiasType, AttnMaskType
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
27
28
29
from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, ShardingType
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

# Type aliases used in the signatures throughout this module.
PRNGKey = Any    # PRNG keys are left untyped.
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
# Anything accepted as a matmul precision: nothing, a name, an enum value,
# or a pair of either (one per operand).
PrecisionLike = Union[
    None,
    str,
    lax.Precision,
    Tuple[str, str],
    Tuple[lax.Precision, lax.Precision],
]
# Callable signature of a parameter initializer: (key, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Flax logical axis rules: (logical_axis_name, mesh_axis_name or None) pairs.
LogicalRules = Sequence[Tuple[str, Union[str, None]]]


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    """Return every axis index of ``shape`` except ``batch_dim``.

    These are the broadcast dimensions used for drop_path. ``batch_dim`` follows
    Python list indexing, so a negative value counts from the last axis; an
    out-of-range value raises ``IndexError``.
    """
    broadcast_dims = [*range(len(shape))]
    del broadcast_dims[batch_dim]
    return broadcast_dims


def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
    """
    Extend the given Flax logical axis rules with the predefined TransformerLayer's
    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to Figure 3 in `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
        for 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.

    .. note::
        This function is only needed when using TransformerLayer. For other modules, such as
        DenseGeneral, please properly set axes of kernels and bias.

    Parameters
    ----------
    rules : Sequence[Tuple[str, Union[str, None]]]
        the base Flax logical axis rules to extend. Each rule must be a pair
        ``(axis_name, mesh_axis_name)`` where ``axis_name`` is a ``str`` and
        ``mesh_axis_name`` is a ``str`` or ``None``.

    Returns
    -------
    extended_rules : Sequence[Tuple[str, Union[str, None]]]
        a tuple holding the given rules, followed by every TE rule whose axis
        name does not already appear in ``rules``.

    Raises
    ------
    AssertionError
        If a rule is malformed, or if ``rules`` maps an axis name that TE also
        defines to a different mesh axis (or to more than one mesh axis).
    """
    # Collect every mesh axis that each logical axis name is mapped to by the caller.
    rules_map = {}
    for item in rules:
        assert len(item) == 2, \
            "The logical axis rule should be like (axis_name, mesh_axis_name)."
        key = item[0]
        val = item[1]
        assert isinstance(key, str), \
            f"The axis_name should be str, but got {type(key)}."
        assert isinstance(val, str) or (val is None), \
            f"The mesh_axis_name should be str or None, but got {type(val)}."
        rules_map.setdefault(key, []).append(val)

    gsr = global_shard_resource()
    te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource),
                             ('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None),
                             ('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None),
                             ('relpos_buckets', None), ('length', None))

    extended_rules = [*rules]
    for item in te_logical_axis_rules:
        key = item[0]
        val = item[1]
        if key in rules_map:
            # A caller-provided rule for an axis TE also defines must match TE's mapping
            # exactly; it is then kept as-is instead of being appended a second time.
            assert len(rules_map[key]) == 1 and rules_map[key][0] == val, \
                "The rule diverged between TE and given rule. " \
                f"Axis:{key} map to {rules_map[key]} in the given" \
                f" rules, but {val} in TE's rules."
        else:
            extended_rules.append(item)
    return tuple(extended_rules)


def _merge_mask(func, *masks: Optional[Array]):
    """Fold the non-None entries of ``masks`` left-to-right with ``func``.

    Returns ``None`` if no mask is given (or all are ``None``), the single mask
    unchanged if only one is present, and ``func(func(m0, m1), m2)...`` otherwise.
    All present masks must have the same rank.
    """
    present = [m for m in masks if m is not None]
    if len(present) == 0:
        return None
    ranks = tuple(m.ndim for m in present)
    assert all(r == ranks[0] for r in ranks), f'masks must have same rank: {ranks}'
    return functools.reduce(func, present)


def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
    """
    Combine attention masks with element-wise logical AND, cast to ``dtype``.

    ``None`` entries are ignored and all remaining masks must share the same rank.
    Returns ``None`` when no non-None mask is given, matching ``combine_biases``,
    rather than raising ``AttributeError`` on ``None.astype``.
    """
    merged = _merge_mask(jnp.logical_and, *masks)
    if merged is None:
        return None
    return merged.astype(dtype)


def combine_biases(*masks: Optional[Array]):
    """
    Combine attention biases by element-wise summation.

    ``None`` entries are ignored; the result is ``None`` when every entry is
    ``None``. All non-None biases must share the same rank.
    """

    def _add(lhs, rhs):
        return lhs + rhs

    return _merge_mask(_add, *masks)


def core_attention(query: Array,
                   key: Array,
                   value: Array,
                   scale_factor: float,
                   transpose_batch_sequence: bool,
                   softmax_type: SoftmaxType = SoftmaxType.SCALED,
                   softmax_sharding_type: ShardingType = ShardingType.SINGLE,
                   mask: Optional[Array] = None,
                   bias: Optional[Array] = None,
                   dropout_rng: Optional[PRNGKey] = None,
                   dropout_rate: float = 0.,
                   deterministic: bool = False,
                   dtype: DType = jnp.float32,
                   float32_logits: bool = False):
    """
    Unfused dot-product attention.

    Computes per-head logits ``Q K^T``, passes them through :class:`Softmax`
    (which applies ``scale_factor``, ``mask`` and ``bias``), optionally applies
    dropout to the attention weights, and contracts the weights with ``value``.

    Parameters
    ----------
    query, key, value : jax.numpy.ndarray
        Rank-4 tensors laid out as (seqlen, batch, heads, head_dim) when
        ``transpose_batch_sequence`` is True, otherwise (batch, seqlen, heads, head_dim).
        ``key`` and ``value`` must share a sequence length; all three must share
        batch size and number of heads; ``query`` and ``key`` must share head_dim.
    scale_factor : float
        Scale applied to the logits inside :class:`Softmax`.
    transpose_batch_sequence : bool
        Selects the input/output layout described above.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Softmax variant to use.
    softmax_sharding_type : ShardingType, default = ShardingType.SINGLE
        Sharding type forwarded to :class:`Softmax`.
    mask : jax.numpy.ndarray, default = None
        Mask forwarded to :class:`Softmax`.
    bias : jax.numpy.ndarray, default = None
        Bias forwarded to :class:`Softmax`.
    dropout_rng : PRNGKey, default = None
        Key for the dropout mask; required when dropout is active.
    dropout_rate : float, default = 0.
        Probability of dropping an attention weight.
    deterministic : bool, default = False
        Disable dropout if set to True.
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        Dtype of the dropout rescaling constant.
    float32_logits : bool, default = False
        Cast ``query`` and ``key`` to float32 before computing the logits.

    Returns
    -------
    jax.numpy.ndarray
        Attention context in the same layout as ``query``.
    """
    assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
    batch_dim = 1 if transpose_batch_sequence else 0
    assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
        'q, k, v batch dims must match.')
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
    sequence_dim = 0 if transpose_batch_sequence else 1
    assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
    assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    # Logits are always produced in (batch, heads, q_len, k_len) order.
    if transpose_batch_sequence:
        attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
    else:
        attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

    # NOTE(review): the softmax output is not cast back to `dtype` here, so with
    # float32_logits the returned context may be float32 — confirm this is intended.
    attn_weights = Softmax(softmax_type=softmax_type,
                           scale_factor=scale_factor,
                           sharding_type=softmax_sharding_type)(attn_weights, mask, bias)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        # Size-1 query axis: each keep/drop decision is shared by every query position.
        dropout_shape = list(attn_weights.shape)
        dropout_shape[-2] = 1
        keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        keep = jnp.broadcast_to(keep, attn_weights.shape)
        # Inverted dropout: rescale kept weights by 1 / keep_prob.
        multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if transpose_batch_sequence:
        return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)

    return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)


# `lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis)` vectorized over
# `start_index` only: the operand, slice size and axis are shared by every start index,
# and the results are stacked along a new leading axis.
dynamic_vector_slice_in_dim = vmap(
    lax.dynamic_slice_in_dim,
    in_axes=(None, 0, None, None),
)


class AttentionType(Enum):
    """
    Attention mask type used by the core attention.

    Each member wraps the corresponding :class:`AttnMaskType`; ``MultiHeadAttention``
    passes ``.value`` as ``attn_mask_type`` to the fused attention kernels and uses
    the member itself to pick the unfused softmax variant.
    """
    PADDING = AttnMaskType.PADDING_MASK
    CAUSAL = AttnMaskType.CAUSAL_MASK


class MultiHeadAttention(nn.Module):
    r"""
    Multi-head Attention (MHA), including Query,
    Key, Value and Output projection.

    Parameters
    ----------
    head_dim : int
        The hidden dimension of each attention head.
    num_heads : int
        The number of attention heads
    dropout_rate : float, default = 0.0
        Dropout probability for the dropout op during multi-head attention.
    dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        to generate Dropout masks in the core attention.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing the QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    use_bias: bool, default = False
        Indicate whether or not to enable bias shifting for QKVO projections.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm : bool, default = False
        Indicate if apply residual connection with the output of layer normalization.
        When True, the returned residual is the layer-normalized ``inputs_q``.
    output_layernorm : bool, default = False
        Indicate if apply a layer normalization at the end of MHA. When True, the
        layer normalization fused into the query/QKV projection is disabled.
    attn_type: AttentionType, default = AttentionType.PADDING
        Indicate the format of the attention mask in the core attention.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    fuse_qkv: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. if set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q K^T}{\sqrt{head\_dim}}`,
        else :math:`Q K^T`
    scaled_query_init: bool, default = `True`
        Whether to divide the initial query projection kernel (WQ) by
        :math:`\sqrt{head\_dim}`.
    float32_logits : bool, default = False
        Whether to compute attention logits in float32.
    """

    head_dim: int
    num_heads: int
    dropout_rate: float = 0.
    dropout_rng_name: str = 'dropout'
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    kernel_init: Initializer = None
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    attn_type: AttentionType = AttentionType.PADDING
    dtype: DType = jnp.float32
    fuse_qkv: bool = True
    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False    # computes logits in float32 for stability.

    def __post_init__(self):
        # Resolve the documented default initializer; must happen before the
        # parent __post_init__ finalizes the module's fields.
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs_q: Array,
                 inputs_kv: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> Tuple[Array, Array]:
        """
        MultiHeadAttention Layer:
        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.

        Parameters
        ----------
        inputs_q : jax.numpy.ndarray
            Input tensor for query projection.
        inputs_kv : jax.numpy.ndarray
            Input tensor for key/value projection. Passing the same object as
            ``inputs_q`` selects the self-attention projection path.
        mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
        bias : jax.numpy.ndarray, default = None
            A tensor used to shift self-attention softmax input.
        *
        decode : bool, default = False
            Indicate whether to prepare and use an autoregressive cache.
        deterministic : bool, default = False
            Disable dropout layers if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        residual : jax.numpy.ndarray
            ``inputs_q``, or the layer-normalized ``inputs_q`` when
            ``apply_residual_connection_post_layernorm`` is True.
        """

        def query_init(*args):
            depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
            return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)

        # Initializer for the fused (hidden, 3, heads * head_dim) QKV kernel.
        def qkv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 3

            q_key, k_key, v_key = jax_random.split(key, num=3)

            q_shape = (shape[0], shape[-1])
            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)

        # Initializer for the fused (hidden, 2, heads * head_dim) KV kernel.
        def kv_init(key, shape, dtype):
            assert len(shape) == 3
            assert shape[-2] == 2

            k_key, v_key = jax_random.split(key)

            k_shape = (shape[0], shape[-1])
            v_shape = (shape[0], shape[-1])

            k_kernel = self.kernel_init(k_key, k_shape, dtype)
            v_kernel = self.kernel_init(v_key, v_shape, dtype)

            return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)

        first_sharding_type, second_sharding_type = infer_sharding_type()

        # The fused kernels only cover this configuration; every failing condition
        # is reported in the warning below.
        canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
        q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
        kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
        fused_attn_supported_seqlen = [128, 256, 384, 512]
        use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
            self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
            q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
            and is_fused_attn_kernel_available()

        if not use_fused_attn:
            reason = ""
            if decode:
                reason += f"decode=False is required but got {decode}, "
            if self.transpose_batch_sequence:
                reason += f"transpose_batch_sequence=False is required " \
                          f"but got {self.transpose_batch_sequence}, "
            if not self.fuse_qkv:
                reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
            if self.dropout_rate != 0:
                # TODO(rewang): add dropout support
                reason += f"no dropout is required but got dropout_rate={self.dropout_rate}, "
            if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
                reason += f"dtype in [BF16, FP16] is required " \
                          f"but got dtype={canonicalize_dtype}, "
            if q_seqlen not in fused_attn_supported_seqlen:
                reason += f"q_seqlen in {fused_attn_supported_seqlen} is required " \
                          f"but got {q_seqlen=}, "
            if kv_seqlen not in fused_attn_supported_seqlen:
                reason += f"kv_seqlen in {fused_attn_supported_seqlen} is required " \
                          f"but got {kv_seqlen=}, "
            if not is_fused_attn_kernel_available():
                reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "

            warnings.warn(f"Fused attention is not enabled, "
                          f"{reason}fall back to unfused attention")

        residual = inputs_q
        if self.fuse_qkv:
            if inputs_q is inputs_kv:
                # Self-attention: one fused QKV projection.
                qkv_proj, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=not self.output_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=(3, self.num_heads * self.head_dim),
                    sharding_type=first_sharding_type,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.apply_residual_connection_post_layernorm,
                    scale_axes=('embed',),
                    kernel_axes=('embed', 'qkv_dim', 'joined_kv'),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(
                        'qkv_dim',
                        'joined_kv',
                    ),
                    name='qkv',
                    dtype=self.dtype)(inputs_q)
                if not use_fused_attn:
                    query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
            else:
                # Cross-attention: separate Q projection plus one fused KV projection.
                query, ln_out = LayerNormDenseGeneral(
                    enable_layernorm=not self.output_layernorm,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.layernorm_epsilon,
                    axis=-1,
                    features=self.num_heads * self.head_dim,
                    sharding_type=first_sharding_type,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                    return_layernorm_output=self.apply_residual_connection_post_layernorm,
                    scale_axes=('embed',),
                    kernel_axes=('embed', 'joined_kv'),
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=('joined_kv',),
                    dtype=self.dtype,
                    kernel_init=query_init,
                    name='query')(inputs_q)
                kv_proj = DenseGeneral(axis=-1,
                                       features=(2, self.num_heads * self.head_dim),
                                       sharding_type=first_sharding_type,
                                       transpose_batch_sequence=self.transpose_batch_sequence,
                                       kernel_axes=('embed', 'kv_dim', 'joined_kv'),
                                       kernel_init=kv_init,
                                       use_bias=self.use_bias,
                                       bias_init=self.bias_init,
                                       bias_axes=(
                                           'kv_dim',
                                           'joined_kv',
                                       ),
                                       name='kv',
                                       dtype=self.dtype)(inputs_kv)
                if not use_fused_attn:
                    key, value = jnp.split(kv_proj, [1], axis=-2)
        else:
            kv_projection = functools.partial(
                DenseGeneral,
                axis=-1,
                features=self.num_heads * self.head_dim,
                sharding_type=first_sharding_type,
                transpose_batch_sequence=self.transpose_batch_sequence,
                kernel_axes=('embed', 'joined_kv'),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=('joined_kv',),
                dtype=self.dtype)
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=not self.output_layernorm,
                layernorm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.layernorm_epsilon,
                axis=-1,
                features=self.num_heads * self.head_dim,
                sharding_type=first_sharding_type,
                transpose_batch_sequence=self.transpose_batch_sequence,
                return_layernorm_output=True,
                scale_axes=('embed',),
                kernel_axes=('embed', 'joined_kv'),
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=('joined_kv',),
                dtype=self.dtype,
                kernel_init=query_init,
                name='query')(inputs_q)

            # Self-attention: K/V are projected from the same normalized input as Q.
            if inputs_q is inputs_kv:
                assert ln_out is not None
                inputs_kv = ln_out

            key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

        if self.apply_residual_connection_post_layernorm:
            assert ln_out is not None
            residual = ln_out

        if not use_fused_attn:
            # Split the joined head dimension into (num_heads, head_dim).
            query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
            key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
            value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
            qkv_sharding_constraint = \
                ('length', 'batch', 'heads', 'kv') \
                if self.transpose_batch_sequence \
                else ('batch', 'length', 'heads', 'kv')
            query = nn_partitioning.with_sharding_constraint(query, qkv_sharding_constraint)
            key = nn_partitioning.with_sharding_constraint(key, qkv_sharding_constraint)
            value = nn_partitioning.with_sharding_constraint(value, qkv_sharding_constraint)

        if decode:
            is_initialized = self.has_variable('cache', 'cached_key')

            cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros, value.shape,
                                         value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                if self.transpose_batch_sequence:
                    length, batch, num_heads, head_dim = cached_key.value.shape
                    expected_shape = (1, batch, num_heads, head_dim)
                    one_hot_indices_shape = (length, 1, 1, 1)
                else:
                    batch, length, num_heads, head_dim = cached_key.value.shape
                    expected_shape = (batch, 1, num_heads, head_dim)
                    one_hot_indices_shape = (1, length, 1, 1)

                # Sanity shape check of cached key against input query.
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        f"expected query shape {expected_shape} instead got {query.shape}.")

                # Write this step's key/value into slot `cur_index` of the cache.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
                key = cached_key.value + key * one_hot_indices
                value = cached_value.value + value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1

                # NOTE(review): `jnp.arange(length) > cur_index` is True for cache slots not
                # yet written; combine_masks merges with logical AND, so if the incoming
                # `mask` also marks excluded positions with True, the result under-masks.
                # Confirm the mask convention expected by Softmax.
                mask = combine_masks(
                    mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)))

                if bias is not None:
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                       jnp.reshape(cur_index, (-1)), 1, -2)

        scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if use_fused_attn:
            assert mask is not None and mask.ndim == 4    # (b, 1, s_q, s_kv)
            assert not self.transpose_batch_sequence
            # TODO(rewang): make it configurable for pre_scale_bias
            attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS

            if inputs_q is inputs_kv:
                qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
                qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv')
                qkv_proj = nn_partitioning.with_sharding_constraint(qkv_proj,
                                                                    qkv_sharding_constraint)
                x = self_fused_attn(qkv_proj,
                                    bias,
                                    mask,
                                    dropout_rng,
                                    attn_bias_type=attn_bias_type,
                                    attn_mask_type=self.attn_type.value,
                                    scaling_factor=scale_factor,
                                    dropout_probability=self.dropout_rate,
                                    is_training=not deterministic,
                                    sharding_type=first_sharding_type)
            else:
                assert bias is None
                query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
                kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim))
                q_sharding_constraint = ('batch', 'length', 'heads', 'kv')
                kv_sharding_constraint = ('batch', 'length', 'kv_dim', 'heads', 'kv')
                query = nn_partitioning.with_sharding_constraint(query, q_sharding_constraint)
                kv_proj = nn_partitioning.with_sharding_constraint(kv_proj, kv_sharding_constraint)

                x = cross_fused_attn(query,
                                     kv_proj,
                                     mask,
                                     dropout_rng,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=self.attn_type.value,
                                     scaling_factor=scale_factor,
                                     dropout_probability=self.dropout_rate,
                                     is_training=not deterministic,
                                     sharding_type=first_sharding_type)
        else:
            softmax_type = SoftmaxType.SCALED
            if self.attn_type is AttentionType.PADDING:
                if mask is not None:
                    softmax_type = SoftmaxType.SCALED_MASKED
            else:
                softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED

            x = core_attention(query,
                               key,
                               value,
                               scale_factor=scale_factor,
                               transpose_batch_sequence=self.transpose_batch_sequence,
                               softmax_type=softmax_type,
                               softmax_sharding_type=first_sharding_type,
                               mask=mask,
                               bias=bias,
                               dropout_rng=dropout_rng,
                               dropout_rate=self.dropout_rate,
                               deterministic=deterministic,
                               dtype=self.dtype,
                               float32_logits=self.float32_logits)

        # Merge (num_heads, head_dim) back into a single joined_kv axis.
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        attn_context_sharding_constraint = \
            ('length', 'batch', 'joined_kv') \
            if self.transpose_batch_sequence \
            else ('batch', 'length', 'joined_kv')
        x = nn_partitioning.with_sharding_constraint(x, attn_context_sharding_constraint)

        out = DenseGeneral(features=inputs_q.shape[-1],
                           sharding_type=second_sharding_type,
                           transpose_batch_sequence=self.transpose_batch_sequence,
                           axis=-1,
                           kernel_init=self.kernel_init,
                           kernel_axes=('joined_kv', 'embed'),
                           use_bias=self.use_bias,
                           bias_init=self.bias_init,
                           bias_axes=('embed',),
                           dtype=self.dtype,
                           name='out')(x)
        return out, residual


class RelativePositionBiases(nn.Module):
    """
    T5-style relative position biases for the attention logits.

    Parameters
    ----------
    num_buckets : int
        The number of buckets to bucket distances between key and query positions into.
    max_distance : int
        The maximum distance before everything is lumped into the last
        distance bucket.
    num_attention_heads : int
        Number of attention heads in the transformer layer.
    embedding_init : Initializer, default = flax.linen.linear.default_embed_init
        Used for initializing relative embedding tables.
    embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets')
        The name of axes used to shard embedding attention bias with a corresponding mesh.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to compute the attention bias. The embedding table
        parameter itself is always allocated in jax.numpy.float32 and is cast
        to this dtype before use.
    """
    num_buckets: int
    max_distance: int
    num_attention_heads: int
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets')
    dtype: DType = jnp.float32

    @nn.compact
    def __call__(self, q_seqlen: int, k_seqlen: int, bidirectional: bool = True):
        """
        Generate relative position embedding attention biases.

        Parameters
        ----------
        q_seqlen : int
            The sequence length of query. Used with numpy, so it must be a
            concrete (non-traced) integer.
        k_seqlen : int
            The sequence length of key. Used with numpy, so it must be a
            concrete (non-traced) integer.
        bidirectional : bool, default = True
            Indicate whether to allow positive memory-query relative position
            embeddings. When False, positive relative positions are clamped to 0
            and all buckets are used for one direction.

        Returns
        -------
        output: jax.numpy.ndarray
            An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
        """
        # relative_position[i, j] = j - i, shape (q_seqlen, k_seqlen).
        context_position = np.arange(q_seqlen, dtype=np.int32)[:, None]
        memory_position = np.arange(k_seqlen, dtype=np.int32)[None, :]
        relative_position = memory_position - context_position

        # Compute relative position bucket (host-side numpy; the operation order
        # below is kept identical to T5 so bucket assignments match exactly).
        rp_bucket = 0
        negative_rp = -relative_position
        rpb_num_buckets = self.num_buckets

        if bidirectional:
            # Half of the buckets are reserved for each direction.
            rpb_num_buckets //= 2
            rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
            negative_rp = np.abs(negative_rp)
        else:
            negative_rp = np.maximum(negative_rp, 0)

        # Small distances get one bucket each; larger ones are binned
        # logarithmically up to max_distance and clipped to the last bucket.
        rpb_max_exact = rpb_num_buckets // 2
        rpb_is_small = negative_rp < rpb_max_exact
        rpb_val_if_large = rpb_max_exact + (
            np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) /
            np.log(self.max_distance / rpb_max_exact) *
            (rpb_num_buckets - rpb_max_exact)).astype(np.int32)
        rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
        rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

        # Compute relative attention bias; the table is stored in float32.
        relative_attention_bias = nn_partitioning.param_with_axes(
            'rel_embedding',
            self.embedding_init, (self.num_attention_heads, self.num_buckets),
            jnp.float32,
            axes=self.embedding_axes)

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

        # Gather via a one-hot matmul: (heads, buckets) x (buckets, q, k) -> (heads, q, k).
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)

        values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot,
                                 (((1,), (0,)), ((), ())))
        return values[jnp.newaxis, ...]


class TransformerLayerType(Enum):
    r"""
    TransformerLayerType is an Enum class to specify a type of TransformerLayer

    Values
    ------
    ENCODER:
        Encoder type of TransformerLayer.
    DECODER:
        Decoder type of TransformerLayer.
    """
    ENCODER = "encoder"
    DECODER = "decoder"


class TransformerLayer(nn.Module):
    r"""
    TransformerLayer is made up of a relative embedding,
    an attention block and a feedforward network (MLP).
    This standard layer is based on the paper “Attention Is All You Need”.

    Parameters
    ----------
    hidden_size: int, default = 512
        The hidden size of each input sample.
    mlp_hidden_size: int, default = 2048
        Intermediate size to which input samples are projected.
    num_attention_heads: int, default = 8
        Number of attention heads in the transformer layer.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    layernorm_epsilon: float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
    hidden_dropout: float, default = 0.1
        Dropout probability for the dropout op after FC2 layer.
    hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden states.
    attention_dropout: float, default = 0.1
        Dropout probability for the dropout op during multi-head attention.
    dropout_rng_name: str, default = 'dropout'
        The key in the RNGs given via flax.linen.Module.apply that is used for
        generating Dropout masks in the Multi-Head Attention.
    mha_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        Used for initializing weights of QKV and Output projection weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_kernel_init: Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights of FC1 and FC2 layers.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    mlp_activations: Sequence[str], default = ('relu', )
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
        If set to False, the layer will not learn additive biases.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias of QKVO projections,
        FC1 and FC2. It is only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    apply_residual_connection_post_layernorm: bool, default = False
        If set to True, residual connections are taken from the output
        of layer norm (by default they are taken from the input of layer norm).
    output_layernorm: bool, default = False
        If set to True, layer normalization is applied on the output side,
        after the final dropout-add. The default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    float32_attention_logits: bool, default = False
        If set to True, attention logits are computed in jax.numpy.float32.
    layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
        If set to TransformerLayerType.DECODER, an additional cross-attention block
        is added after self-attention. This can be used for structures like `T5`
        Transformer in conjunction with the TransformerLayerType.ENCODER option.
    enable_relative_embedding: bool, default = True
        Whether to enable relative embedding as shifting of attention logits.
    relative_embedding: flax.linen.Module, default = None
        The module for relative embedding execution, only used when
        :attr:`enable_relative_embedding=True`. Default is None, which will create
        an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
        Default: RelativePositionBiases( num_buckets=32, max_distance=128,
        num_attention_heads=self.num_attention_heads, dtype=self.dtype,
        embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
        name='relpos_bias')

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main
        path of the residual block.
    fuse_qkv_params: bool, default = True
        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
        cross-attention.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    scale_attn_logits: bool, default = False
        Indicate whether to scale attention logits.
        If set to True, :math:`\frac{Q \cdot K}{\sqrt{head\_dim}}`,
        else :math:`Q \cdot K`.
    scaled_query_init: bool, default = `True`
        Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    dropout_rng_name: str = 'dropout'
    mha_kernel_init: Initializer = None
    mlp_kernel_init: Initializer = None
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        # Fill in the documented default initializers; attributes must be set
        # before the base-class __post_init__ runs.
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
                                                                    'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs: Array,
                 encoded: Array = None,
                 attention_mask: Array = None,
                 encoder_decoder_mask: Array = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: Optional[int] = None):
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensor. Must be 3D, laid out as described by
            :attr:`transpose_batch_sequence`.
        encoded : jax.numpy.ndarray, default = None
            Output tensors of the encoder block to be fed into the decoder block if using
            :attr:`layer_type=TransformerLayerType.DECODER`.
        attention_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out self-attention softmax input.
        encoder_decoder_mask : jax.numpy.ndarray, default = None
            Boolean tensor used to mask out cross-attention softmax input when
            :attr:`layer_type=TransformerLayerType.DECODER`.
        deterministic: bool, default = False
            Disable dropout layers if set to True.
        decode: bool, default = False
            Indicate whether to prepare and use an autoregressive cache
            in Multi-head attention (MHA).
        max_decode_length : int, default = None
            The maximum length to generate relative embedding biases when
            :attr:`layer_type=TransformerLayerType.DECODER` and
            :attr:`enable_relative_embedding=True`.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        assert self.layer_type in TransformerLayerType, \
                "layer_type should be one of TransformerLayerType" \
                f", but got {self.layer_type}."

        assert self.hidden_size % self.num_attention_heads == 0, \
                "hidden_size should be multiples of num_attention_heads" \
                f", but got {self.hidden_size=} and {self.num_attention_heads=}."

        assert self.layer_type == TransformerLayerType.DECODER or \
              (self.layer_type == TransformerLayerType.ENCODER and decode is False), \
               "decode should be False when layer_type == TransformerLayerType.ENCODER."

        # Validate rank before any inputs.shape[...] lookups below.
        assert inputs.ndim == 3, f"inputs should be a 3D tensor, but got {inputs.ndim=}."

        head_dim = self.hidden_size // self.num_attention_heads

        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        attn_bias = None
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(num_buckets=32,
                                                 max_distance=128,
                                                 num_attention_heads=self.num_attention_heads,
                                                 dtype=self.dtype,
                                                 embedding_init=nn.initializers.variance_scaling(
                                                     1.0, 'fan_avg', 'uniform'),
                                                 name='relpos_bias')
            else:
                rel_emb = self.relative_embedding

            if self.layer_type == TransformerLayerType.ENCODER:
                attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
            else:
                # In decode mode, size the bias by max_decode_length (when given)
                # instead of the current input length.
                if decode and max_decode_length:
                    bias_seqlen = max_decode_length
                else:
                    bias_seqlen = inputs.shape[sequence_dim]
                attn_bias = rel_emb(bias_seqlen, bias_seqlen, False)

        # Make name be exactly the same as T5X, since names would affect
        # RNGKey during init and apply. Maybe not needed in the future.
        if self.layer_type == TransformerLayerType.ENCODER:
            mha_name = 'attention'
            self_attn_type = AttentionType.PADDING
        else:
            mha_name = 'self_attention'
            self_attn_type = AttentionType.CAUSAL

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x, residual = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            dtype=self.dtype,
            head_dim=head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            attn_type=self_attn_type,
            fuse_qkv=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            name=mha_name)(inputs,
                           inputs,
                           attention_mask,
                           attn_bias,
                           deterministic=deterministic,
                           decode=decode)

        def hidden_dropout(x, deterministic):
            # Dropout on hidden states; hidden_dropout_dims are the broadcast
            # (mask-sharing) dims and must be valid axes of x.
            assert isinstance(self.hidden_dropout_dims, Sequence)
            x_shape_len = len(x.shape)
            for dims in self.hidden_dropout_dims:
                assert -x_shape_len <= dims < x_shape_len

            return nn.Dropout(rate=self.hidden_dropout,
                              broadcast_dims=self.hidden_dropout_dims)(x,
                                                                       deterministic=deterministic)

        x = hidden_dropout(x, deterministic)
        if self.drop_path > 0.0:
            # Stochastic depth: the mask is broadcast over every dim except batch.
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
        x = x + residual

        mlp_input = x
        if self.layer_type == TransformerLayerType.DECODER:
            assert encoded is not None, \
                "encoded is required when layer_type == TransformerLayerType.DECODER."

            y, residual = MultiHeadAttention(
                num_heads=self.num_attention_heads,
                dtype=self.dtype,
                head_dim=head_dim,
                transpose_batch_sequence=self.transpose_batch_sequence,
                dropout_rate=self.attention_dropout,
                dropout_rng_name=self.dropout_rng_name,
                layernorm_type=self.layernorm_type,
                layernorm_epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                apply_residual_connection_post_layernorm=self.
                apply_residual_connection_post_layernorm,
                output_layernorm=False,    # Must do LayerNorm before MHA.
                attn_type=AttentionType.PADDING,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,
                fuse_qkv=self.fuse_qkv_params,
                kernel_init=self.mha_kernel_init,
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                name='encoder_decoder_attention')(x,
                                                  encoded,
                                                  encoder_decoder_mask,
                                                  deterministic=deterministic)
            y = hidden_dropout(y, deterministic)
            mlp_input = y + residual

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.layernorm_epsilon,
            major_sharding_type=infer_major_sharding_type(),
            transpose_batch_sequence=self.transpose_batch_sequence,
            return_layernorm_output=self.apply_residual_connection_post_layernorm,
            intermediate_dim=self.mlp_hidden_size,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.hidden_dropout,
            intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
            dtype=self.dtype,
            scale_axes=('embed',),
            kernel_init=self.mlp_kernel_init,
            kernel_axes_1=('embed', 'act', 'mlp'),
            kernel_axes_2=('mlp', 'embed'),
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes_1=('act', 'mlp'),
            bias_axes_2=('embed',),
            name='mlp',
        )(mlp_input, deterministic=deterministic)

        if self.apply_residual_connection_post_layernorm:
            # The LayerNorm output was requested via return_layernorm_output above.
            assert ln_out is not None
            residual = ln_out

        z = hidden_dropout(z, deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            ln_sharding_type, _ = infer_sharding_type()
            z = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          epsilon=self.layernorm_epsilon,
                          scale_axes=('embed',),
                          bias_axes=('embed',),
                          transpose_batch_sequence=self.transpose_batch_sequence,
                          dtype=self.dtype,
                          sharding_type=ln_sharding_type,
                          name="output_layer_norm")(z)

        return z