# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from transformer_engine.common import recipe

from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import (
    QuantizerFactory,
    get_quantize_config,
    QuantizeMeta,
    QuantizeMetaSet,
    ScalingMode,
    TensorSource,
)

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return the given scale initializer, or the default implied by zero_centered_gamma."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    """Create the normalization scale (and, for layernorm, bias) parameters on `module`."""
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"Don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: Array):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
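    # For example (an illustrative case, not an API guarantee): with axis=(-1,)
    # and features=(out_dim,), the combined expression below becomes
    #   "...i,is,sn->...n"
    # i.e. a down-projection onto the rank axis "s" followed by an
    # up-projection back to out_dim, fused into a single einsum.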
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.

    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask) * (shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
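
    Examples
    --------
    A minimal usage sketch (the shapes, dtype, and scale below are illustrative
    assumptions, not requirements of this module):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        # [batch, heads, q_seqlen, k_seqlen]
        logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
        module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
        variables = module.init(jax.random.PRNGKey(0), logits)
        probs = module.apply(variables, logits)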
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be one of"
                    " [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module:
    regular and root mean square layer normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{\mathrm{RMS}[x]} * \gamma

    .. math::
        \mathrm{RMS}[x] = \sqrt{\mathrm{E}[x^2] + \epsilon}

    :math:`\gamma` is a learnable affine transform parameter of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
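
    Examples
    --------
    A minimal usage sketch (shapes and dtype are illustrative assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((4, 128, 512), dtype=jnp.float32)  # [batch, seqlen, hidden]
        module = LayerNorm(layernorm_type="rmsnorm")
        variables = module.init(jax.random.PRNGKey(0), x)
        y = module.apply(variables, x)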
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(
        self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
    ):
        """
        Generate a quantizer set for a GEMM and, when delayed scaling is in use,
        the FP8 scale/amax-history metadata backing it.
        """

        def generate_quantize_meta(quantizer_name: str):
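            # Creates (or fetches) the per-tensor FP8 scale and amax-history as
            # Flax variables so they persist across steps for delayed scaling.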
            collection_name = (
                variable_collection
                if variable_collection is not None
                else get_quantize_config().COLLECTION_NAME
            )
            scale = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (get_quantize_config().AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if get_quantize_config().get_scaling_mode(
            TensorSource.X
        ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
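
    Examples
    --------
    A minimal usage sketch (shapes, dtype, and feature sizes are illustrative
    assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)
        module = DenseGeneral(features=1024)
        variables = module.init(jax.random.PRNGKey(0), x)
        y = module.apply(variables, x)  # shape (4, 128, 1024)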
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
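            # Broadcast the bias over the leading (batch/sequence) dimensions
            # before adding it to the dense output.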
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
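
    Examples
    --------
    A minimal usage sketch (shapes, dtype, and feature sizes are illustrative
    assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)
        module = LayerNormDenseGeneral(features=1024, return_layernorm_output=False)
        variables = module.init(jax.random.PRNGKey(0), x)
        y, ln_out = module.apply(variables, x)  # ln_out is None here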
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of two successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    activation_params: dict, default = None
        The parameters needed (if any) by the activation functions specified in
        :attr:`activations`. At the moment, only ('clamped_silu', 'clamped_linear'),
        the clamped SwiGLU variant used in GPT-OSS, requires additional parameters.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key of the RNG dict given to flax.linen.Module.apply that is used to
        generate dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str, default = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str, default = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
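
    Examples
    --------
    A minimal usage sketch (shapes, dtype, and sizes are illustrative assumptions):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)
        module = LayerNormMLP(intermediate_dim=2048, activations=("gelu", "linear"))
        variables = module.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = module.apply(variables, x, deterministic=True)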
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    activation_params: dict = None
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
            ("clamped_silu", "clamped_linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                # Callable activations cannot be normalized against the string
                # pools above; only string names are accepted here.
                raise ValueError(f"activations must be given as strings, got {act!r}")
            normalized_acts.append(act.lower())
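        # A gated pair may be specified in either order; normalize it so the
        # gate activation comes first, matching the entries in gated_act_pool.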
        normalized_acts = tuple(
            reversed(normalized_acts)
            if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
            else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
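            # Draws an independent kernel per activation branch and stacks them
            # along `stack_axis`; for wi_kernel (stack_axis=-2) this yields
            # shape (hidden_in, num_activations, intermediate_dim).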
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                activation_params=self.activation_params,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                # The stacked kernel gains a num_activations axis, hence the +1.
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output