# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes

# Type aliases used throughout this module.
PRNGKey = Any  # JAX pseudo-random number generator key.
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)

# Matmul precision accepted by jax.lax dot-like primitives: a single precision
# or an (lhs, rhs) pair, given as strings or lax.Precision values.
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]

# Parameter initializer signature: (rng_key, shape, dtype) -> Array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


57
def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
58
59
60
61
62
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
63
64
65
    return nn.initializers.zeros


66
def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    """Create the learnable normalization parameters on *module*.

    Allocates the ``scale`` (gamma) parameter in ``dtype`` and casts it to
    ``input_dtype`` for the forward computation. For 'layernorm' a shift
    parameter ``ln_bias`` (beta) is created the same way; 'rmsnorm' has no
    shift, so ``None`` is returned in its place.

    Returns:
        Tuple of ``(scale, bias)``; ``bias`` is ``None`` for rmsnorm.
    """
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        # Only the two norm types are supported; anything else is a
        # programming error upstream of this helper.
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
101
    if fn_or_string == "linear":
102
103
104
105
106
107
108
109
110
111
112
113
114
115
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
116
117
118
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
119
120
121
122
123
124
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


125
126
127
128
def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
129
    hidden_in_names = "ijklm"[: len(axis)]
130
    assert len(features) <= 5
131
132
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"
133
134
135
136
137
138
139
140
141

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
142
143
144
145
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )
146
147
148
149
150
151

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


152
class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.

    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python
        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch, heads, q_seqlen, k_seqlen = inputs.shape[:4]
        dtype_in = inputs.dtype

        # The bias (if any) is folded into the logits on both code paths.
        logits = inputs if bias is None else inputs + bias.astype(dtype_in)

        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, dtype_in
        ):
            # Fused-kernel path. The mask argument is only meaningful for the
            # SCALED_MASKED variant; drop it for everything else.
            kernel_mask = mask if self.softmax_type is SoftmaxType.SCALED_MASKED else None
            probs = softmax(logits, kernel_mask, self.scale_factor, self.softmax_type)
        # Fall back to the pure-JAX reference implementations.
        elif self.softmax_type is SoftmaxType.SCALED:
            probs = jax_scaled_softmax(logits, self.scale_factor)
        elif self.softmax_type is SoftmaxType.SCALED_MASKED:
            probs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
        elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
            probs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
        else:
            raise ValueError(
                f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
            )

        # Softmax must not silently promote/demote the activation dtype.
        assert dtype_in == probs.dtype
        return probs


215
class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        # Resolve the default gamma initializer from zero_centered_gamma
        # before flax finalizes the dataclass fields.
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype
        # Normalization is performed over the last (hidden) dimension only.
        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        # The fused op must preserve the activation dtype.
        assert out.dtype == input_dtype
        return out
339
340
341


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def _make_meta(name: str) -> QuantizeMeta:
            # Allocate the per-tensor scale and amax-history variables in the
            # quantization collection; names carry the caller's postfix so
            # several GEMMs inside one module stay distinct.
            scale_var = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            )
            history_var = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            )
            return QuantizeMeta(scale=scale_var.value, amax_history=history_var.value)

        kwargs = {}
        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            # Delayed scaling needs persistent meta for activations (x),
            # weights (kernel) and gradients (grad).
            kwargs["quantize_meta_set"] = QuantizeMetaSet(
                x=_make_meta("x"),
                kernel=_make_meta("kernel"),
                grad=_make_meta("grad"),
            )

        return QuantizerFactory.create_set(**kwargs)
379
380
381


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias : bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation : bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim : int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha : float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis : Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes : Tuple[str, ...], default = ()
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is an empty tuple, which
        means not to insert sharding constraint.
    sequence_parallel_output : bool, default = False
        Produce a sequence-parallel output with the first non-batch dimension
        sharded. (NOTE(review): forwarded verbatim to ``dense()``; confirm the
        exact sharding axis against that implementation.)

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    input_axes: Tuple[str, ...] = ()
    sequence_parallel_output: bool = False

    def __post_init__(self):
        # Sequence-first layouts are not officially supported; warn rather
        # than fail so existing users can proceed at their own risk.
        if self.transpose_batch_sequence:
            _dense_warning(
                "TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        # Resolve negative axis indices against the actual input rank.
        axis = _normalize_axes(axis, inputs.ndim)

        # Kernel shape: (contracted input dims...) + (output feature dims...).
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        # Under FP8 the quantized GEMM consumes the kernel in its allocation
        # dtype; otherwise match the activation dtype for a plain matmul.
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        # Kernel contraction dims are its leading dims, by construction above.
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            sequence_parallel_output=self.sequence_parallel_output,
        )

        if self.enable_low_rank_adaptation:
            # LoRA A projects the contracted input dims down to the LoRA rank.
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            # LoRA B expands the rank back to the final feature dim; zero init
            # makes the LoRA branch a no-op at the start of training.
            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            # Broadcast the bias across all leading (batch/sequence) dims.
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
557
    Applies layer normalization followed by dense layer transformation to the incoming data.
558
559
560
561

    Parameters
    ----------
    features : Union[Iterable[int], int]
562
        The hidden size of each output sample.
563
    enable_layernorm: bool, default = True
564
        Indicate whether to enable layer normalization before dense layer transformation.
565
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
566
        Indicate the type of layer normalization.
567
    epsilon : float, default = 1e-6
568
        A value added to the denominator of layer normalization for numerical stability.
569
570
571
572
573
574
575
576
577
578
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
579
        Used for initializing scale factors :math:`\gamma`.
580
581
582
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
583
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
584
    scale_axes : Tuple[str, ...], default = ('embed', )
585
586
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
587
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
588
589
590
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
591
592
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
593
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
594
595
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
596
597
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
598
    kernel_axes : Tuple[str, ...], default = ()
599
        The name of axes used to shard the weights with a corresponding mesh.
600
    use_bias: bool, default = False
601
602
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
603
    bias_init: Initializer, default = flax.linen.initializers.zeros
604
605
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
606
    bias_axes: Tuple[str, ...], default = ()
607
608
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
609
    return_layernorm_output: bool, default = True
610
        Indicate whether to return the output of layer normalization.
611
        If set False, return None as the second tensor in outputs.
612
    enable_low_rank_adaptation: bool, default = False
613
        Indicate whether to enable low rank adaptation for each dense layer.
614
615
616
617
618
619
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
620
    axis:  Union[Iterable[int], int], default = -1
621
        An integer tuple with axes to apply the transformation on.
622
623
624
625
626
627
628
629
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
630
631
632

    Optimization parameters
    -----------------------
633
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
634
        The data type used to allocate the initial parameters.
635
    transpose_batch_sequence : bool, default = True
636
637
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
638
639
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
640
        The factor to scale the output from `DenseGeneral`. It should be a float
641
642
643
644
645
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
646
    layernorm_type: str = "layernorm"
647
    epsilon: float = 1e-6
648
649
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
650
    scale_axes: Tuple[str, ...] = ("embed",)
651
    ln_bias_init: Initializer = nn.initializers.zeros
652
    ln_bias_axes: Tuple[str, ...] = ("embed",)
653
654
655
656
657
658
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
659
660
661
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
662
663
664
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
665
666
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
667
668
669
    depth_scaling: float = None

    def __post_init__(self):
        # Sequence-first layouts are not officially supported; warn rather
        # than fail so existing users can proceed at their own risk.
        if self.transpose_batch_sequence:
            _ln_dense_warning(
                "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
                "inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
                "Use sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        # Resolve the default gamma initializer from zero_centered_gamma
        # before flax finalizes the dataclass fields.
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        # NOTE(review): this eagerly-created quantizer set appears unused by
        # __call__, which builds its own via generate_quantizer_set() — confirm.
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        # LayerNorm can only be fused into the GEMM when FP8 is on and the caller
        # does not need the standalone normalized tensor back.
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                # Unfused path: normalize here, dense layer applied below.
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        # NOTE: `features` is rebound from the scalar above to the output-feature tuple.
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        # Contracting dims are flattened into a single leading kernel dimension.
        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            # Keep compute dtype aligned with the activations when not quantizing.
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            # LoRA A projects the contracted input dims down to the low-rank space.
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            # LoRA B maps back to the output features; zero-init so LoRA starts as a no-op.
            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            # Broadcast bias across the leading (batch/sequence) dimensions.
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output

class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout
        masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden state.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    # kernel_init of None is replaced with a variance-scaling default in __post_init__.
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"

    def __post_init__(self):
        """Resolve defaulted initializers and warn about unsupported layouts.

        Runs before flax finalizes the dataclass fields: emits the
        sequence-first-input warning, fills in the default kernel and
        LayerNorm-scale initializers when the caller left them as None,
        then delegates to the base class.
        """
        if self.transpose_batch_sequence:
            _ln_mlp_warning(
                "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )

        # Default weight initializer: fan-in variance scaling with truncated normal.
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )

        # LayerNorm scale default depends on whether gamma is zero-centered.
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma
        )

        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        # Activation combinations with a fused kernel implementation.
        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                # NOTE(review): returning False here violates the documented
                # (outputs, ln_outputs) contract for callable activations — confirm
                # whether callables should instead flow through the unfused path.
                return False
            normalized_acts.append(act.lower())
        # Canonicalize gated activations to (act, 'linear') ordering.
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        # Fully-fused LN+MLP kernel requires FP8, a supported activation, and no dropout.
        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )

        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            # Initialize one kernel per activation and stack them along stack_axis.
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
                # Constrain the dot output sharding to the non-contracting axes of
                # the input and kernel.
                dot_1_output_axes = (
                    *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
                    *get_non_contracting_logical_axes(
                        kernel_1.ndim, self.kernel_axes_1, contract_ind
                    ),
                )
                x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)

            if self.enable_low_rank_adaptation:
                # BUG FIX: unpack the contracted dims into a flat shape tuple —
                # previously this produced a nested tuple ((in_dim,), rank).
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                # BUG FIX: len(tuple) + 1, not len(tuple + 1) (TypeError).
                # +1 accounts for the stacked activation dimension added by kernel_1_init.
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                # Zero-init B so LoRA is an identity perturbation at the start.
                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                # Generic fallback: apply each activation to its slice and multiply
                # (gated-activation semantics).
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output