# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from transformer_engine.common import recipe

from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import (
    QuantizerFactory,
    get_quantize_config,
    QuantizeMeta,
    QuantizeMetaSet,
    ScalingMode,
    TensorSource,
)

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)

PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]

Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return the given scale initializer, or the default implied by zero_centered_gamma."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask
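
# Example (a minimal sketch of how the helper above combines biases; the arrays
# below are hypothetical, not part of the library API):
#
#     attn_bias = jnp.zeros((1, 8, 128, 128))
#     rel_pos_bias = jnp.ones((1, 8, 128, 128))
#     combined = _combine_biases(attn_bias, None, rel_pos_bias)  # None entries dropped, rest summed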


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
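
# Example (a minimal sketch of the LoRA einsum above; the shapes and sizes below
# are hypothetical, not part of the library API):
#
#     x = jnp.ones((4, 16, 128))       # [batch, seq, hidden_in]
#     lora_a = jnp.ones((128, 8))      # [hidden_in, rank]
#     lora_b = jnp.zeros((8, 256))     # [rank, hidden_out]
#     y = _apply_low_rank_adaptation(
#         x, axis=(2,), features=(256,), lora_a_kernel=lora_a, lora_b_kernel=lora_b, alpha=16.0
#     )
#     # einsum "...i,is,sn->...n", scaled by alpha / rank = 2.0; y.shape == (4, 16, 256)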


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs
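
# Example usage (a minimal sketch; the shapes and scale below are hypothetical):
#
#     rng = jax_random.PRNGKey(0)
#     logits = jax_random.normal(rng, (2, 8, 128, 128), dtype=jnp.float16)  # [batch, heads, q, k]
#     probs = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED).apply({}, logits)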


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is a learnable affine transform parameter of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out
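
# Example usage (a minimal sketch; the shapes and dtype below are hypothetical):
#
#     rng = jax_random.PRNGKey(0)
#     x = jnp.ones((8, 128, 512), dtype=jnp.bfloat16)   # [batch, seq, hidden]
#     ln = LayerNorm(layernorm_type="rmsnorm")
#     variables = ln.init(rng, x)
#     y = ln.apply(variables, x)                        # y.dtype == x.dtype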


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(
        self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
    ):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            collection_name = (
                variable_collection
                if variable_collection is not None
                else get_quantize_config().COLLECTION_NAME
            )
            scale = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (get_quantize_config().AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if get_quantize_config().get_scaling_mode(
            TensorSource.X
        ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f" got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y
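
# Example usage (a minimal sketch; the feature sizes and logical axis names below are hypothetical):
#
#     rng = jax_random.PRNGKey(0)
#     x = jnp.ones((8, 128, 512))                       # [batch, seq, hidden_in]
#     layer = DenseGeneral(features=1024, kernel_axes=("embed", "mlp"))
#     variables = layer.init(rng, x)
#     y = layer.apply(variables, x)                     # y.shape == (8, 128, 1024)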


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                transpose_batch_sequence=self.transpose_batch_sequence,
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output
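
# Example usage (a minimal sketch; the sizes below are hypothetical):
#
#     rng = jax_random.PRNGKey(0)
#     x = jnp.ones((8, 128, 512))
#     layer = LayerNormDenseGeneral(
#         features=3 * 512, layernorm_type="rmsnorm", return_layernorm_output=False
#     )
#     variables = layer.init(rng, x)
#     y, ln_out = layer.apply(variables, x)             # ln_out is None here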


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    activation_params: dict, default = None
        The parameters needed (if any) by the activation functions specified in :attr:`activations`.
        At the moment only ('clamped_silu', 'clamped_linear'), i.e. the clamped SwiGLU used in
        GPT OSS, needs additional parameters.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str, default = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str, default = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    activation_params: dict = None
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
            ("clamped_silu", "clamped_linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                return False
            normalized_acts.append(act.lower())
        normalized_acts = tuple(
            reversed(normalized_acts)
            if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
            else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )

        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                activation_params=self.activation_params,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output
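
# Example usage (a minimal sketch; the sizes and the gated-GELU activation pair below are
# hypothetical):
#
#     rng = jax_random.PRNGKey(0)
#     x = jnp.ones((8, 128, 512))
#     mlp = LayerNormMLP(intermediate_dim=2048, activations=("gelu", "linear"))
#     variables = mlp.init(rng, x, deterministic=True)
#     y, ln_out = mlp.apply(variables, x, deterministic=True)   # y.shape == x.shape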