# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
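    """Return `original_init` if it is given; otherwise pick the default scale
    initializer matching `zero_centered_gamma` (zeros when True, ones otherwise)."""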
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
    scale = scale.astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
        bias = jnp.asarray(bias, input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0
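    # For the common case of axis=(-1,) and features=(out_dim,), the einsum built below
    # reduces to x @ lora_a_kernel @ lora_b_kernel, i.e. a rank-`rank` LoRA update that
    # is multiplied by `scaling` (alpha / rank, or 1.0 when alpha is None).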

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
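
    Examples
    --------
    A minimal single-device sketch; the shapes, dtype and RNG key below are
    illustrative assumptions, not requirements of this module:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
        logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), logits)
        probs = layer.apply(variables, logits)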
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
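
    Examples
    --------
    A minimal single-device sketch; shapes, dtype and the RNG key are illustrative
    assumptions, and the mesh/axis-rule setup needed for sharding is omitted:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNorm(layernorm_type='rmsnorm', epsilon=1e-6)
        x = jnp.ones((32, 128, 512), dtype=jnp.float32)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)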
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
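
    Examples
    --------
    A minimal single-device sketch; the feature size, shapes and RNG key are
    illustrative assumptions, and FP8/sharding setup is omitted:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = DenseGeneral(features=1024)
        x = jnp.ones((32, 128, 512), dtype=jnp.float32)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # y.shape == (32, 128, 1024)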
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = lora_a_kernel.astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
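
    Examples
    --------
    A minimal single-device sketch; shapes and the RNG key are illustrative
    assumptions, and FP8/sharding setup is omitted. With the default
    :attr:`return_layernorm_output=True` both tensors are returned:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormDenseGeneral(features=1024, transpose_batch_sequence=False)
        x = jnp.ones((32, 128, 512), dtype=jnp.float32)
        variables = layer.init(jax.random.PRNGKey(0), x)
        out, ln_out = layer.apply(variables, x)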
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = lora_a_kernel.astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The name of the RNG collection, given to flax.linen.Module.apply, that is used to generate Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions of the intermediate (hidden) activation that will share the same dropout mask.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
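
    Examples
    --------
    A minimal single-device sketch; shapes, activations and the RNG key are
    illustrative assumptions, and FP8/sharding setup is omitted:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormMLP(intermediate_dim=2048,
                             activations=('gelu', 'linear'),
                             transpose_batch_sequence=False)
        x = jnp.ones((32, 128, 512), dtype=jnp.float32)
        variables = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
        out, ln_out = layer.apply(variables, x, deterministic=True)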
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
984
            If :attr:`return_layernorm_output=False`, then this would be None.
985
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            # Callable activations are kept as-is; they are handled by the unfused
            # path below via `_convert_to_activation_function`.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
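            # Draw one independent kernel per activation with `self.kernel_init` and
            # stack them along `stack_axis` (e.g. wi_kernel gets shape
            # (hidden_in, num_activations, intermediate_dim)).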
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes(
            "wi_kernel",
            kernel_1_init,
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
            axes=self.kernel_axes_1,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = nn_partitioning.param_with_axes(
            "wo_kernel",
            self.kernel_init,
            kernel_2_shape,
            self.dtype,
            axes=self.kernel_axes_2,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = nn_partitioning.param_with_axes(
                "wi_bias",
                self.bias_init,
                bias_1_shape,
                self.dtype,
                axes=self.bias_axes_1,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = nn_partitioning.param_with_axes(
                "wo_bias",
                self.bias_init,
                bias_2_shape,
                self.dtype,
                axes=self.bias_axes_2,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            dot_1_output_axes = (
                *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
                *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
            )
            x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_a_kernel",
                    kernel_1_init,
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                    axes=wi_lora_a_kernel_axes,
                )
                wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_b_kernel",
                    nn.initializers.zeros,
                    wi_lora_b_kernel_shape,
                    self.dtype,
                    axes=wi_lora_b_kernel_axes,
                )
                wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_a_kernel",
                    self.kernel_init,
                    wo_lora_a_kernel_shape,
                    self.dtype,
                    axes=wo_lora_a_kernel_axes,
                )
                wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_b_kernel",
                    nn.initializers.zeros,
                    wo_lora_b_kernel_shape,
                    self.dtype,
                    axes=wo_lora_b_kernel_axes,
                )
                wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output