# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp

from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # Returns a tuple by convention; len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
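# Example (illustrative): negative axes are wrapped by ndim, so
# _normalize_axes((-1, 1), ndim=3) == (2, 1).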


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return the given scale initializer, or the default implied by zero_centered_gamma."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string name (or pass through a callable) to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
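# Example (illustrative): _convert_to_activation_function("gelu") resolves to
# flax.linen.gelu via getattr, while "linear" yields an identity lambda.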


def _combine_biases(*masks: List[Array]):
    """Combine attention biases (any None entries are skipped)."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""
    # Assign single-letter einsum indices: "ijklm" for the contracted input axes,
    # "nopqr" for the output feature axes, and "s" for the LoRA rank dimension.
    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
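
# Example (illustrative): with one contracted input axis and features of rank 2,
# the expression above becomes "...i,ins,nso->...no", i.e. x is contracted with
# the rank-s LoRA A kernel and then expanded by the LoRA B kernel.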


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask) * (shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use the default JAX-based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be one of"
                    " [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs
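
# Example (sketch, hypothetical names such as head_dim/logits/mask): applying the
# parameter-free Softmax module via flax's apply with an empty variable dict.
#
#     import math
#     softmax_module = Softmax(scale_factor=1.0 / math.sqrt(head_dim),
#                              softmax_type=SoftmaxType.SCALED_MASKED)
#     probs = softmax_module.apply({}, logits, mask=mask)  # logits: [b, h, sq, sk]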


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors should be in
        (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()
    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out
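
# Example (sketch, hypothetical names x/hidden): initializing and applying LayerNorm.
#
#     layer = LayerNorm(layernorm_type="rmsnorm")
#     params = layer.init(jax.random.PRNGKey(0), x)  # x: [batch, seqlen, hidden]
#     y = layer.apply(params, x)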


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class for Transformer Engine modules.
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 quantization metadata for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set
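
# Example (sketch): subclasses call generate_quantizer_set() inside __call__ to get
# per-GEMM quantizers; LayerNormMLP below uses distinct postfixes for its two GEMMs,
# e.g. self.generate_quantizer_set("_0") and self.generate_quantizer_set("_1").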


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    sequence_parallel_output: bool, default = False
        Produce a sequence-parallel output, with the first non-batch dimension sharded
        over the tensor-parallel mesh axis.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors should be in
        (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    input_axes: Tuple[str, ...] = ()
    sequence_parallel_output: bool = False

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _dense_warning(
                "TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes), "
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            sequence_parallel_output=self.sequence_parallel_output,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y
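
# Example (sketch, hypothetical names num_heads/head_dim/x and logical axis names):
# a multi-head projection with DenseGeneral producing a rank-2 feature output.
#
#     layer = DenseGeneral(features=(num_heads, head_dim), axis=-1,
#                          kernel_axes=("embed", "heads", "kv"))
#     params = layer.init(jax.random.PRNGKey(0), x)  # x: [batch, seqlen, hidden]
#     y = layer.apply(params, x)                     # y: [batch, seqlen, num_heads, head_dim]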


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors should be in
        (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _ln_dense_warning(
                "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
                "inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
                "Use sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
736
737
738
739
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
740
                    norm_type=self.layernorm_type,
741
742
743
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

759
        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output
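
# Example (sketch, hypothetical names hidden_size/x): LayerNormDenseGeneral returns a
# pair (dense_output, layer_norm_output); the second is None when
# return_layernorm_output=False.
#
#     layer = LayerNormDenseGeneral(features=3 * hidden_size,
#                                   transpose_batch_sequence=False,
#                                   return_layernorm_output=False)
#     params = layer.init(jax.random.PRNGKey(0), x)  # x: [batch, seqlen, hidden]
#     y, _ = layer.apply(params, x)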


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the bias of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the bias of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str, default = 'ffn1'
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str, default = 'ffn2'
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the batch and sequence-length axes of the input tensors
        are swapped. If set to True, the input tensors should be in
        (seqlen, batch, hidden); otherwise (batch, seqlen, hidden).
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _ln_mlp_warning(
                "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        # Lower-case string activation names; a non-string (callable) activation cannot
        # match the fused pools above, so it simply disables the fused path below.
        # (The previous code returned False from __call__ here, which was a bug.)
        normalized_acts = tuple(
            act.lower() if isinstance(act, str) else act for act in self.activations
        )
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.enable_low_rank_adaptation:
                # Unpack the contracted-axis dims (the previous code nested the tuple
                # and computed len(tuple + 1), both of which raised TypeErrors).
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output
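
# Example (sketch, hypothetical names hidden_size/x): a gated-GELU MLP block.
# LayerNormMLP returns a pair (output, layer_norm_output); the second is None
# when return_layernorm_output=False.
#
#     mlp = LayerNormMLP(intermediate_dim=4 * hidden_size,
#                        activations=("gelu", "linear"),
#                        transpose_batch_sequence=False,
#                        return_layernorm_output=False)
#     params = mlp.init(jax.random.PRNGKey(0), x)  # x: [batch, seqlen, hidden]
#     y, _ = mlp.apply(params, x, deterministic=True)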