# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..layernorm_mlp import fused_layernorm_fp8_mlp, activation_lu
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, input_dtype, dtype
):
    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
    scale = scale.astype(input_dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "layernorm":
        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
        bias = bias.astype(input_dtype)
    else:
        assert layernorm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
78
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
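
# For example (an illustrative note): _convert_to_activation_function("gelu") returns
# flax.linen.gelu, "linear" returns an identity lambda, and a callable such as
# jnp.tanh is passed through unchanged.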


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
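
# A worked instance of the einsum expressions built above (an illustrative note): with
# a single contracting axis and a single output feature axis, i.e. len(axis) == 1 and
# len(features) == 1, the expressions reduce to
#
#     x:      "...i"   (leading batch dims, hidden-in "i")
#     lora_a: "is"     (hidden-in "i", rank "s")
#     lora_b: "sn"     (rank "s", hidden-out "n")
#     output = jnp.einsum("...i,is,sn->...n", x, lora_a_kernel, lora_b_kernel) * scaling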


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):

            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(
                    mask > 0,
                    jnp.full(mask.shape, -1e10),
                    jnp.full(mask.shape, 0.0),
                )
                attention_bias = attention_bias.astype(input_dtype)

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(input_dtype)

            # If the kernel for self.softmax_type (e.g. SCALED_UPPER_TRIANG_MASKED)
            # is unavailable, fall back to the pure scaled softmax custom call.
            if is_softmax_kernel_available(
                SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype
            ):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        assert input_dtype == outputs.dtype
        return outputs
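
# A minimal usage sketch for Softmax (illustrative only; the shapes, head count, and
# scale factor below are assumptions). Softmax holds no parameters, so it can be
# applied with an empty variable collection:
#
#     import math
#     rng = jax_random.PRNGKey(0)
#     logits = jax_random.normal(rng, (2, 16, 128, 128), jnp.float32)  # [b, h, q, k]
#     probs = Softmax(scale_factor=1.0 / math.sqrt(64)).apply({}, logits)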


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out
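
# A minimal usage sketch for LayerNorm (illustrative; the shapes and the RMSNorm
# choice are assumptions):
#
#     rng = jax_random.PRNGKey(0)
#     x = jax_random.normal(rng, (4, 128, 512), jnp.bfloat16)  # [batch, seqlen, hidden]
#     ln = LayerNorm(layernorm_type="rmsnorm", epsilon=1e-6)
#     params = ln.init(rng, x)
#     y = ln.apply(params, x)  # same shape and dtype as x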


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    @staticmethod
    def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage:
        """
        Generate a set of FP8 metadata (amax history and scale) for the input,
        weight, and gradient of a GEMM.
        """

        input_name_post_fix = f"_i_{postfix}"
        weight_name_post_fix = f"_w_{postfix}"
        grad_name_post_fix = f"_g_{postfix}"

        def generate_a_set(target_postfix):
            amax = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}",
                jnp.zeros,
                (FP8Helper.AMAX_HISTORY_LEN,),
                jnp.float32,
                axes=(None,),
            )

            scale = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}",
                jnp.ones,
                (1,),
                jnp.float32,
                axes=(None,),
            )

            return amax.value, scale.value

        input_amax, input_scale = generate_a_set(input_name_post_fix)
        weight_amax, weight_scale = generate_a_set(weight_name_post_fix)
        grad_amax, grad_scale = generate_a_set(grad_name_post_fix)

        return FP8MetaPackage(
            input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale
        )
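
    # Layout note (derived from the code above): for postfix "0", this method registers
    # three (amax, scale) variable pairs in the FP8Helper.FP8_COLLECTION_NAME collection,
    # one per GEMM operand, using the postfixes "_i_0" (input), "_w_0" (weight), and
    # "_g_0" (gradient). Each amax buffer has shape (FP8Helper.AMAX_HISTORY_LEN,) and
    # each scale has shape (1,).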


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
450

451
        input_dtype = inputs.dtype
452
453
454
455
456
457
458
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        if not FP8Helper.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            )
            bias = bias.astype(input_dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        y = type_safe_dot_general(
            inputs, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y
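
# A minimal usage sketch for DenseGeneral (illustrative; the feature sizes and the
# logical axis names are assumptions):
#
#     rng = jax_random.PRNGKey(0)
#     x = jax_random.normal(rng, (4, 128, 512), jnp.float32)  # [batch, seqlen, hidden]
#     dense = DenseGeneral(features=1024, kernel_axes=("embed", "mlp"))
#     params = dense.init(rng, x)
#     y = dense.apply(params, x)  # shape (4, 128, 1024)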


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None will be returned as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        input_dtype = inputs.dtype
        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            assert self.axis == -1  # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        if not FP8Helper.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        if fuse_layernorm:
            z = layernorm_fp8_dot(
                y,
                kernel,
                scale,
                ln_bias,
                fp8_meta_pkg,
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = type_safe_dot_general(
                y, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            )
            bias = bias.astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype
        return z, ln_output  # dense_output, layer_norm_output
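
# A minimal usage sketch for LayerNormDenseGeneral (illustrative; shapes are
# assumptions). The module returns a pair of tensors: the dense output and, when
# return_layernorm_output=True, the layernorm output (otherwise None):
#
#     rng = jax_random.PRNGKey(0)
#     x = jax_random.normal(rng, (128, 4, 512), jnp.float32)  # [seqlen, batch, hidden]
#     ln_dense = LayerNormDenseGeneral(features=1536, return_layernorm_output=False)
#     params = ln_dense.init(rng, x)
#     out, ln_out = ln_dense.apply(params, x)  # ln_out is None here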


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first linear transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second linear transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None will be returned as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the RNGs dict (given via flax.linen.Module.apply) that is used
        to generate Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden states.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        input_dtype = inputs.dtype
        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
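        # The normalization below lowercases string activations and, when the tuple
        # starts with "linear" (e.g. ("linear", "gelu")), reverses it so that gated
        # pairs match the gated_act_pool entries above.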
        normalized_acts = []
        for act in self.activations:
            # Callables are kept as-is; only string activations can match the
            # fused activation pools above.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1  # Only support axis == -1 at this moment
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        wi_fp8_meta_pkg = None
        wo_fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            wi_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
            wo_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("1")

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes(
            "wi_kernel",
            kernel_1_init,
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
            axes=self.kernel_axes_1,
        )
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        if not FP8Helper.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes(
            "wo_kernel",
            self.kernel_init,
            kernel_2_param_shape,
            self.dtype,
            axes=self.kernel_axes_2,
        )
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        if not FP8Helper.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)
        contract_ind = tuple(range(0, len(axis)))

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            assert self.axis == -1  # Only support axis == -1 at this moment

            if self.use_bias:
                bias_1_shape = intermediate_dim
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias",
                    self.bias_init,
                    bias_1_shape,
                    self.dtype,
                    axes=self.bias_axes_1,
                )
                bias_1 = bias_1.astype(input_dtype)

                bias_2_shape = (hidden_size,)
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias",
                    self.bias_init,
                    bias_2_shape,
                    self.dtype,
                    axes=self.bias_axes_2,
                )
                bias_2 = bias_2.astype(input_dtype)
            else:
                bias_1 = None
                bias_2 = None

            out = fused_layernorm_fp8_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                [wi_fp8_meta_pkg, wo_fp8_meta_pkg],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                use_bias=self.use_bias,
            )

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_fp8_dot(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    wi_fp8_meta_pkg,
                    self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = type_safe_dot_general(
                    y, kernel_1, fp8_meta_pkg=wi_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_shape = (
                    *kernel_1_shape[: len(axis)],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_shape = (
                    kernel_1_each_shape[0],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_each_shape = (
                    kernel_1_each_shape[0],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
                wi_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_a_kernel",
                    kernel_1_init,
                    num_activations,
                    -2,
                    wi_lora_a_kernel_init_each_shape,
                    self.dtype,
                    axes=wi_lora_a_kernel_axes,
                )
                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
                wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_b_kernel",
                    nn.initializers.zeros,
                    wi_lora_b_kernel_shape,
                    self.dtype,
                    axes=wi_lora_b_kernel_axes,
                )
                wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    intermediate_dim,
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_1 = None
            if self.use_bias:
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias",
                    self.bias_init,
                    intermediate_dim,
                    self.dtype,
                    axes=self.bias_axes_1,
                )
                bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
                bias_1 = bias_1.astype(input_dtype)
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation_lu(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                # Remove act axis
                z = jnp.reshape(z, (*z.shape[:-2], -1))
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = type_safe_dot_general(
                z, kernel_2, fp8_meta_pkg=wo_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_a_kernel",
                    self.kernel_init,
                    wo_lora_a_kernel_shape,
                    self.dtype,
                    axes=wo_lora_a_kernel_axes,
                )
                wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_b_kernel",
                    nn.initializers.zeros,
                    wo_lora_b_kernel_shape,
                    self.dtype,
                    axes=wo_lora_b_kernel_axes,
                )
                wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_2 = None
            if self.use_bias:
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias",
                    self.bias_init,
                    (hidden_size,),
                    self.dtype,
                    axes=self.bias_axes_2,
                )
                bias_2 = bias_2.astype(input_dtype)
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # output, layer_norm_output
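
# A minimal usage sketch for LayerNormMLP (illustrative; the shapes, activation pair,
# and rng names below are assumptions, not requirements):
#
#     rng = jax_random.PRNGKey(0)
#     x = jax_random.normal(rng, (128, 4, 512), jnp.float32)  # [seqlen, batch, hidden]
#     mlp = LayerNormMLP(intermediate_dim=2048, activations=("gelu", "linear"))
#     params = mlp.init({"params": rng, "dropout": rng}, x, deterministic=True)
#     out, ln_out = mlp.apply(params, x, deterministic=True)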