# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..layernorm_mlp import fused_layernorm_fp8_mlp, activation_lu
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return the default layernorm scale initializer when `original_init` is None."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
):
    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
    scale = scale.astype(dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "layernorm":
        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
        bias = bias.astype(dtype)
    else:
        assert layernorm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low-Rank Adaptation (LoRA): computes (x @ lora_a @ lora_b) * (alpha / rank)."""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0
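    # Einsum layout: lora_a contracts the input axes of x and produces the leading
    # output-feature axes plus a rank axis; lora_b then contracts the rank axis and
    # produces the final output-feature axis. The result is scaled by alpha / rank above.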

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
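
    Example (a minimal usage sketch; the shapes and scale factor below are illustrative
    assumptions, not requirements):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        softmax_module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
        logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)  # [batch, heads, q_seqlen, k_seqlen]
        variables = softmax_module.init(jax.random.PRNGKey(0), logits)
        probs = softmax_module.apply(variables, logits)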
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs

        if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype
        ):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(
                    mask > 0,
                    jnp.full(mask.shape, -1e10).astype(dtype),
                    jnp.full(mask.shape, 0.0).astype(dtype),
                )

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # If self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED (or another
            # masked variant) and its kernel is unavailable, fall back to the pure scaled
            # softmax custom call.
            if is_softmax_kernel_available(
                SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
            ):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module:
    regular layer normalization and root mean square layer normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is a learnable affine transform parameter of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        the data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the batch and sequence-length axes of the input tensors are
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
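
    Example (a minimal usage sketch; the shapes and dtype below are illustrative
    assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNorm(layernorm_type="rmsnorm", dtype=jnp.float32)
        x = jnp.ones((8, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)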
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        x = x.astype(self.dtype)

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            self.dtype,
        )
        return layernorm(
            x,
            scale,
            ln_bias,
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    @staticmethod
    def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage:
        """
        Generate a set of FP8 meta (amax history and scale) for the input, weight, and
        gradient of a GEMM.
        """

        input_name_post_fix = f"_i_{postfix}"
        weight_name_post_fix = f"_w_{postfix}"
        grad_name_post_fix = f"_g_{postfix}"

        def generate_a_set(target_postfix):
            amax = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}",
                jnp.zeros,
                (FP8Helper.AMAX_HISTORY_LEN,),
                jnp.float32,
                axes=(None,),
            )

            scale = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}",
                jnp.ones,
                (1,),
                jnp.float32,
                axes=(None,),
            )

            return amax.value, scale.value

        input_amax, input_scale = generate_a_set(input_name_post_fix)
        weight_amax, weight_scale = generate_a_set(weight_name_post_fix)
        grad_amax, grad_scale = generate_a_set(grad_name_post_fix)

        return FP8MetaPackage(
            input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale
        )


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the batch and sequence-length axes of the input tensors are
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
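
    Example (a minimal usage sketch; shapes and feature sizes are illustrative
    assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = DenseGeneral(features=1024, use_bias=True)
        x = jnp.ones((8, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # y.shape == (8, 128, 1024)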
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        kernel = kernel.astype(self.dtype)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        y = type_safe_dot_general(
            inputs, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the batch and sequence-length axes of the input tensors are
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
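
    Example (a minimal usage sketch; shapes, feature sizes, and flags are illustrative
    assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormDenseGeneral(features=1024, transpose_batch_sequence=False)
        x = jnp.ones((8, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)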
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )
        inputs = inputs.astype(self.dtype)

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            assert self.axis == -1  # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        kernel = kernel.astype(self.dtype)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        if fuse_layernorm:
            z = layernorm_fp8_dot(
                y,
                kernel,
                scale,
                ln_bias,
                fp8_meta_pkg,
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = type_safe_dot_general(
                y, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first linear transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second linear transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating
        Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden-layer dropout.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the batch and sequence-length axes of the input tensors are
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
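
    Example (a minimal usage sketch; shapes, activations, and flags are illustrative
    assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        mlp = LayerNormMLP(
            intermediate_dim=2048,
            activations=("gelu", "linear"),
            transpose_batch_sequence=False,
        )
        x = jnp.ones((8, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
        variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
        out, ln_out = mlp.apply(variables, x, deterministic=True)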
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        inputs = inputs.astype(self.dtype)

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                return False
            normalized_acts.append(act.lower())
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1  # Only support axis == -1 at this moment
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

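        # wi_kernel stores one (hidden_in, intermediate_dim) kernel per activation,
        # stacked along axis -2, so after the reshape below its shape is
        # (hidden_in, num_activations, intermediate_dim).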
        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        wi_fp8_meta_pkg = None
        wo_fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            wi_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
            wo_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("1")

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes(
            "wi_kernel",
            kernel_1_init,
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
            axes=self.kernel_axes_1,
        )
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        kernel_1 = kernel_1.astype(self.dtype)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes(
            "wo_kernel",
            self.kernel_init,
            kernel_2_param_shape,
            self.dtype,
            axes=self.kernel_axes_2,
        )
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        kernel_2 = kernel_2.astype(self.dtype)
        contract_ind = tuple(range(0, len(axis)))

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            assert self.axis == -1  # Only support axis == -1 at this moment

            if self.use_bias:
                bias_1_shape = intermediate_dim
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1
                )
                bias_1 = bias_1.astype(self.dtype)

                bias_2_shape = (hidden_size,)
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2
                )
                bias_2 = bias_2.astype(self.dtype)
            else:
                bias_1 = None
                bias_2 = None

            out = fused_layernorm_fp8_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                [wi_fp8_meta_pkg, wo_fp8_meta_pkg],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                use_bias=self.use_bias,
            )

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_fp8_dot(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    wi_fp8_meta_pkg,
                    self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = type_safe_dot_general(
                    y, kernel_1, fp8_meta_pkg=wi_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_shape = (
                    *kernel_1_shape[: len(axis)],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_shape = (
                    kernel_1_each_shape[0],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_each_shape = (
                    kernel_1_each_shape[0],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
                wi_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_a_kernel",
                    kernel_1_init,
                    num_activations,
                    -2,
                    wi_lora_a_kernel_init_each_shape,
                    self.dtype,
                    axes=wi_lora_a_kernel_axes,
                )
                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
                wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_b_kernel",
                    nn.initializers.zeros,
                    wi_lora_b_kernel_shape,
                    self.dtype,
                    axes=wi_lora_b_kernel_axes,
                )
                wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    intermediate_dim,
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_1 = None
            if self.use_bias:
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1
                )
                bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
                bias_1 = bias_1.astype(self.dtype)
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation_lu(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                # Remove act axis
                z = jnp.reshape(z, (*z.shape[:-2], -1))
            z = z.astype(self.dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(self.dtype)

            # DenseGeneral 2
            out = type_safe_dot_general(
                z, kernel_2, fp8_meta_pkg=wo_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_a_kernel",
                    self.kernel_init,
                    wo_lora_a_kernel_shape,
                    self.dtype,
                    axes=wo_lora_a_kernel_axes,
                )
                wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_b_kernel",
                    nn.initializers.zeros,
                    wo_lora_b_kernel_shape,
                    self.dtype,
                    axes=wo_lora_b_kernel_axes,
                )
                wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_2 = None
            if self.use_bias:
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2
                )
                bias_2 = bias_2.astype(self.dtype)
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        return out, ln_output  # Output, layer_norm_output