# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random

from .dot import fp8_dot
from .fp8 import FP8GemmPackage, FP8Helper
from .layernorm import canonicalize_layernorm_type
from .layernorm import layernorm, layernorm_fp8_dot
from .mlp import fp8_ln_mlp, geglu
from .sharding import MajorShardingType, ShardingType, infer_sharding_type
from .softmax import SoftmaxType, is_softmax_kernel_available, softmax

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
                                 bias_axes, dtype):
    scale = nn_partitioning.param_with_axes('scale',
                                            scale_init,
                                            shape,
                                            jnp.float32,
                                            axes=scale_axes)
    scale = jnp.asarray(scale, dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'layernorm':
        bias = nn_partitioning.param_with_axes('ln_bias',
                                               bias_init,
                                               shape,
                                               jnp.float32,
                                               axes=bias_axes)
        bias = jnp.asarray(bias, dtype)
    else:
        assert layernorm_type == 'rmsnorm'
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: Array):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(map(lambda x: x.ndim == masks[0].ndim,
                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


class Softmax(nn.Module):
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.

    Optimization parameters
    -----------------------
    sharding_type : ShardingType, default = ShardingType.SINGLE
        Indicate the sharding pattern.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED
    sharding_type: ShardingType = ShardingType.SINGLE

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs

        if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
                self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type,
                              self.sharding_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(mask > 0,
                                            jnp.full(mask.shape, -1e10).astype(dtype),
                                            jnp.full(mask.shape, 0.).astype(dtype))

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # If self.softmax_type (e.g. SoftmaxType.SCALED_UPPER_TRIANG_MASKED) has no
            # available kernel, fall back to the pure scaled softmax custom call if possible.
            if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
                                           dtype):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED,
                                  self.sharding_type)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs
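
# Example usage of Softmax (a minimal sketch; the shapes, dtype, and scale factor
# below are illustrative assumptions, not values prescribed by this module):
#
#   import jax
#   logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.float16)   # [batch, heads, q_seqlen, k_seqlen]
#   softmax_module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
#   variables = softmax_module.init(jax.random.PRNGKey(0), logits)
#   probs = softmax_module.apply(variables, logits)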


class LayerNorm(nn.Module):
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters
    matching the size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is a learnable affine transform parameter
    matching the size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    scale_init : Initializer, default = flax.linen.initializers.ones
        Used for initializing scale factors :math:`\gamma`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    sharding_type : ShardingType, default = ShardingType.SINGLE
        Indicate the sharding pattern.
    """
    epsilon: float = 1e-6
    layernorm_type: str = 'layernorm'
    scale_init: Initializer = nn.initializers.ones
    scale_axes: Tuple[str, ...] = ('embed',)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ('embed',)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type: ShardingType = ShardingType.SINGLE

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                      self.scale_init, self.scale_axes,
                                                      self.bias_init, self.bias_axes, self.dtype)

        return layernorm(x,
                         scale,
                         ln_bias,
                         self.layernorm_type,
                         sharding_type=self.sharding_type,
                         dp_dim_index=1 if self.transpose_batch_sequence else 0,
                         epsilon=self.epsilon)
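
# Example usage of LayerNorm (a minimal sketch; the input shape is an illustrative
# assumption):
#
#   import jax
#   x = jnp.ones((32, 128, 512))                      # (batch, seqlen, hidden)
#   ln = LayerNorm(layernorm_type='rmsnorm')
#   variables = ln.init(jax.random.PRNGKey(0), x)     # creates the 'scale' parameter
#   y = ln.apply(variables, x)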


class TransformerEngineBase(nn.Module):
    """
    Base class for Transformer Engine modules, providing the FP8 meta helpers.
    """

    @staticmethod
    def get_fp8_metas(num_of_gemm: int) -> Tuple[jnp.ndarray, ...]:
        """
        Get the FP8 metas
        """
        num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        axes = ('fp8_meta_axis', 'fp8_meta_history')

        fp8_max = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                     FP8Helper.FP8_MAX_NAME,
                                                     FP8Helper.generate_fp8_max_array,
                                                     num_of_meta,
                                                     axes=axes)
        fp8_metas_amax = nn_partitioning.variable_with_axes(
            FP8Helper.FP8_COLLECTION_NAME,
            FP8Helper.FP8_AMAX_NAME,
            jnp.zeros, (num_of_meta, FP8Helper.AMAX_HISTORY_LEN),
            jnp.float32,
            axes=axes)
        fp8_metas_scale = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                             FP8Helper.FP8_SCALE_NAME,
                                                             jnp.ones, (num_of_meta, 1),
                                                             jnp.float32,
                                                             axes=axes)
        fp8_metas_scale_inv = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                                 FP8Helper.FP8_SCALE_INV_NAME,
                                                                 jnp.ones, (num_of_meta, 1),
                                                                 jnp.float32,
                                                                 axes=axes)

        return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value

    @staticmethod
    def get_fp8_gemm_package(num_of_gemm: int, inputs: jnp.ndarray,
                             kernels: List[jnp.ndarray]) -> FP8GemmPackage:
        """
        Get the FP8 metas
        """
        assert num_of_gemm == len(kernels)
        fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
            TransformerEngineBase.get_fp8_metas(num_of_gemm)

        return FP8GemmPackage(num_of_gemm, inputs, kernels, fp8_max, fp8_metas_amax,
                              fp8_metas_scale, fp8_metas_scale_inv)


class DenseGeneral(TransformerEngineBase):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    sharding_type : ShardingType, default = ShardingType.SINGLE
        Indicate the sharding pattern.
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type: ShardingType = ShardingType.SINGLE

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init, (self.features,),
                                                   self.dtype,
                                                   axes=self.bias_axes)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))

        if FP8Helper.enable_fp8():
            fp8_gemm_package = \
                TransformerEngineBase.get_fp8_gemm_package(1, inputs, [kernel])
            y = fp8_dot(fp8_gemm_package,
                        FP8Helper.FWD_DTYPE,
                        FP8Helper.BWD_DTYPE, (axis, contract_ind),
                        sharding_type=self.sharding_type,
                        dp_dim_index=1 if self.transpose_batch_sequence else 0)
        else:
            kernel = jnp.asarray(kernel, self.dtype)
            y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))

        if bias is not None:
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
        return y
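
# Example usage of DenseGeneral (a minimal sketch; shapes and axis names are
# illustrative assumptions, and FP8 is assumed to be disabled so the plain
# lax.dot_general path runs):
#
#   import jax
#   x = jnp.ones((32, 128, 512))
#   dense = DenseGeneral(features=1024, kernel_axes=('embed', 'mlp'))
#   variables = dense.init(jax.random.PRNGKey(0), x)
#   y = dense.apply(variables, x)                     # y.shape == (32, 128, 1024)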


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    scale_init : Initializer, default = flax.linen.initializers.ones
        Used for initializing scale factors :math:`\gamma`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in outputs.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor by which the output of the dense transformation is divided.
        It should be a float value or None. If None, no scaling is applied.
    sharding_type : ShardingType, default = ShardingType.SINGLE
        Indicate the sharding pattern.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    scale_init: Initializer = nn.initializers.ones
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    depth_scaling: float = None
    sharding_type: ShardingType = ShardingType.SINGLE

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        ln_output = None

        fuse_layernorm = FP8Helper.enable_fp8(
        ) and not self.return_layernorm_output and self.enable_layernorm

        if self.enable_layernorm:
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              sharding_type=self.sharding_type,
                              dp_dim_index=1 if self.transpose_batch_sequence else 0,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([y.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        contract_ind = tuple(range(0, len(axis)))

        if FP8Helper.enable_fp8():
            fp8_gemm_package = \
                    TransformerEngineBase.get_fp8_gemm_package(1, y, [kernel])

            if not fuse_layernorm:
                z = fp8_dot(fp8_gemm_package,
                            FP8Helper.FWD_DTYPE,
                            FP8Helper.BWD_DTYPE, (axis, contract_ind),
                            sharding_type=self.sharding_type,
                            dp_dim_index=1 if self.transpose_batch_sequence else 0)
            else:
                z = layernorm_fp8_dot(fp8_gemm_package,
                                      scale,
                                      ln_bias,
                                      self.layernorm_type,
                                      FP8Helper.FWD_DTYPE,
                                      FP8Helper.BWD_DTYPE, (axis, contract_ind),
                                      sharding_type=self.sharding_type,
                                      dp_dim_index=1 if self.transpose_batch_sequence else 0,
                                      epsilon=self.epsilon)
        else:
            kernel = jnp.asarray(kernel, self.dtype)
            z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init, (self.features,),
                                                   self.dtype,
                                                   axes=self.bias_axes)

        if bias is not None:
            z += jnp.reshape(bias, (1,) * (z.ndim - 1) + (-1,))

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output    # dense_output, layer_norm_output
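
# Example usage of LayerNormDenseGeneral (a minimal sketch; shapes are illustrative
# assumptions; note transpose_batch_sequence defaults to True, so the input is
# laid out as (seqlen, batch, hidden)):
#
#   import jax
#   x = jnp.ones((128, 32, 512))
#   ln_dense = LayerNormDenseGeneral(features=1024, layernorm_type='rmsnorm')
#   variables = ln_dense.init(jax.random.PRNGKey(0), x)
#   y, ln_out = ln_dense.apply(variables, x)          # ln_out is the normalized input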


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    scale_init : Initializer, default = flax.linen.initializers.ones
        Used for initializing scale factors :math:`\gamma`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('mlp',)
        The name of axes used to shard the bias with a corresponding mesh for
        the first linear transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard the bias with a corresponding mesh for
        the second linear transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    major_sharding_type : MajorShardingType, default = MajorShardingType.SINGLE
        Indicate the sharding pattern.
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    scale_init: Initializer = nn.initializers.ones
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp')
    kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ('mlp',)
    bias_axes_2: Tuple[str, ...] = ('embed',)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ('relu',)
    intermediate_dropout_rate: float = 0.1
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    major_sharding_type: MajorShardingType = MajorShardingType.SINGLE

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        ln_output = None

        fuse_layernorm = FP8Helper.enable_fp8(
        ) and not self.return_layernorm_output and self.enable_layernorm

        use_fused_ln_mlp = fuse_layernorm \
            and (not self.use_bias) and self.activations == ('gelu', 'linear') \
                and (self.intermediate_dropout_rate < 1e-3)

        first_sharding_type, second_sharding_type = infer_sharding_type(self.major_sharding_type)

        # LayerNorm
        if self.enable_layernorm:
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              sharding_type=first_sharding_type,
                              dp_dim_index=1 if self.transpose_batch_sequence else 0,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)

        num_of_gemm = 2
        if use_fused_ln_mlp:
            num_activations = len(self.activations)
            axis = _canonicalize_tuple(self.axis)
            axis = _normalize_axes(axis, inputs.ndim)

            intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
            kernel_1_shape = tuple(inputs.shape[ax] for ax in axis) + intermediate_dim
            kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
            kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
                                                       kernel_1_init,
                                                       num_activations,
                                                       -2,
                                                       kernel_1_each_shape,
                                                       jnp.float32,
                                                       axes=self.kernel_axes_1)
            kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
            hidden_size = inputs.shape[-1]
            hidden_size_tuple = _canonicalize_tuple(hidden_size)
            kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
            kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
            kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
                                                       self.kernel_init,
                                                       kernel_2_param_shape,
                                                       jnp.float32,
                                                       axes=self.kernel_axes_2)
            kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
            contract_ind = tuple(range(0, len(axis)))

            fp8_gemm_package = \
                TransformerEngineBase.get_fp8_gemm_package(num_of_gemm, y, [kernel_1, kernel_2])
            out = fp8_ln_mlp(fp8_gemm_package,
                             scale,
                             ln_bias,
                             self.layernorm_type,
                             FP8Helper.FWD_DTYPE,
                             FP8Helper.BWD_DTYPE,
                             epsilon=self.epsilon,
                             contracting_dims=(axis, contract_ind),
                             major_sharding_type=self.major_sharding_type,
                             dp_dim_index=1 if self.transpose_batch_sequence else 0,
                             activations=self.activations)
        else:    # not use_fused_ln_mlp

            def fp8_meta_generator():
                fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = (None, None, None,
                                                                                 None)
                if FP8Helper.enable_fp8():
                    fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
                        TransformerEngineBase.get_fp8_metas(num_of_gemm)
                return fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv

            fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
                fp8_meta_generator()

            # DenseGeneral 1
            activations = []
            num_activations = len(self.activations)
            axis = _canonicalize_tuple(self.axis)
            axis = _normalize_axes(axis, y.ndim)

            intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
            kernel_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
            kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
            kernel = nn_partitioning.param_with_axes('wi_kernel',
                                                     kernel_1_init,
                                                     num_activations,
                                                     -2,
                                                     kernel_1_each_shape,
                                                     jnp.float32,
                                                     axes=self.kernel_axes_1)
            kernel = jnp.reshape(kernel, kernel_shape)
            contract_ind = tuple(range(0, len(axis)))

            if FP8Helper.enable_fp8():
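                # The FP8 metas for both GEMMs were allocated together by fp8_meta_generator();
                # the first NUM_META_PER_GEMM rows of each meta tensor belong to this first GEMM.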
                fp8_gemm_package = FP8GemmPackage(
                    1, y, [kernel], fp8_max[:FP8Helper.NUM_META_PER_GEMM, :],
                    fp8_metas_amax[:FP8Helper.NUM_META_PER_GEMM, :],
                    fp8_metas_scale[:FP8Helper.NUM_META_PER_GEMM, :],
                    fp8_metas_scale_inv[:FP8Helper.NUM_META_PER_GEMM, :])

                if not fuse_layernorm:
                    x = fp8_dot(fp8_gemm_package,
                                FP8Helper.FWD_DTYPE,
                                FP8Helper.BWD_DTYPE, (axis, contract_ind),
                                sharding_type=first_sharding_type,
                                dp_dim_index=1 if self.transpose_batch_sequence else 0)
                else:
                    x = layernorm_fp8_dot(fp8_gemm_package,
                                          scale,
                                          ln_bias,
                                          self.layernorm_type,
                                          FP8Helper.FWD_DTYPE,
                                          FP8Helper.BWD_DTYPE, (axis, contract_ind),
                                          sharding_type=first_sharding_type,
                                          dp_dim_index=1 if self.transpose_batch_sequence else 0,
                                          epsilon=self.epsilon)
            else:    # not enable fp8
                kernel = jnp.asarray(kernel, self.dtype)
                x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wi_bias',
                                                       self.bias_init, (self.intermediate_dim,),
                                                       self.dtype,
                                                       axes=self.bias_axes_1)
                x += jnp.reshape(bias, (1,) * (x.ndim - 1) + (-1,))

            if self.activations == ('gelu', 'linear'):
                z = geglu(x,
                          contracting_dims=(-2, -1),
                          sharding_type=second_sharding_type,
                          dp_dim_index=1 if self.transpose_batch_sequence else 0)
            else:
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(self.activations):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                z = jnp.reshape(z, (*z.shape[:-2], -1))

            z = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
                z, deterministic=deterministic)    # Broadcast along length.

            # DenseGeneral 2
            hidden_size = inputs.shape[-1]
            hidden_size_tuple = _canonicalize_tuple(hidden_size)
            axis = _canonicalize_tuple(self.axis)
            axis = _normalize_axes(axis, z.ndim)

            kernel_shape = tuple(z.shape[ax] for ax in axis) + hidden_size_tuple
            kernel_param_shape = (np.prod([z.shape[ax] for ax in axis]), np.prod(hidden_size_tuple))
            kernel = nn_partitioning.param_with_axes('wo_kernel',
                                                     self.kernel_init,
                                                     kernel_param_shape,
                                                     jnp.float32,
                                                     axes=self.kernel_axes_2)
            kernel = jnp.reshape(kernel, kernel_shape)

            contract_ind = tuple(range(0, len(axis)))

            if FP8Helper.enable_fp8():
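                # The second GEMM consumes the remaining FP8 meta rows (allocated together
                # with the first GEMM's rows above).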
                fp8_gemm_package = FP8GemmPackage(
                    1, z, [kernel], fp8_max[FP8Helper.NUM_META_PER_GEMM:, :],
                    fp8_metas_amax[FP8Helper.NUM_META_PER_GEMM:, :],
                    fp8_metas_scale[FP8Helper.NUM_META_PER_GEMM:, :],
                    fp8_metas_scale_inv[FP8Helper.NUM_META_PER_GEMM:, :])

                out = fp8_dot(fp8_gemm_package,
                              FP8Helper.FWD_DTYPE,
                              FP8Helper.BWD_DTYPE, (axis, contract_ind),
                              sharding_type=second_sharding_type,
                              dp_dim_index=1 if self.transpose_batch_sequence else 0)
            else:
                kernel = jnp.asarray(kernel, self.dtype)
                out = lax.dot_general(z, kernel, ((axis, contract_ind), ((), ())))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wo_bias',
                                                       self.bias_init, (hidden_size,),
                                                       self.dtype,
                                                       axes=self.bias_axes_2)
                out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))

        return out, ln_output    # Output, layer_norm_output
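
# Example usage of LayerNormMLP (a minimal sketch; shapes are illustrative
# assumptions and dropout is skipped via deterministic=True, so no 'dropout'
# PRNG key is required):
#
#   import jax
#   x = jnp.ones((128, 32, 512))                      # (seqlen, batch, hidden)
#   mlp = LayerNormMLP(intermediate_dim=2048, activations=('gelu', 'linear'))
#   variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
#   y, ln_out = mlp.apply(variables, x, deterministic=True)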