# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp

from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import (
    QuantizerFactory,
    get_quantize_config,
    QuantizeMetaSet,
    TensorSource,
    get_quantize_config_with_recipe,
)

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return the given layernorm scale initializer, or the default implied by zero_centered_gamma."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    """Create the layernorm scale parameter and, for 'layernorm', the shift parameter."""
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
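
# A minimal sketch (illustrative only, not part of the library API) of how the helper above
# composes shapes for a single contraction axis and a single output feature dimension.
# Only len(axis) and len(features) matter for building the einsum expression.
#
#     x = jnp.ones((2, 16, 128))        # [batch, seq, hidden_in]
#     lora_a = jnp.ones((128, 8))       # [hidden_in, rank]
#     lora_b = jnp.ones((8, 256))       # [rank, hidden_out]
#     y = _apply_low_rank_adaptation(
#         x, axis=(2,), features=(256,),
#         lora_a_kernel=lora_a, lora_b_kernel=lora_b, alpha=16.0,
#     )
#     # y.shape == (2, 16, 256); the result is scaled by alpha / rank = 2.0.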


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs
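
# A hedged usage sketch (shapes are illustrative, not part of the library API): applying
# scaled masked softmax to attention logits of shape [batch, heads, q_seqlen, k_seqlen].
# The module has no parameters, so an empty variable collection can be passed to apply().
#
#     logits = jnp.zeros((4, 8, 128, 128), dtype=jnp.bfloat16)
#     mask = jnp.zeros((4, 1, 128, 128), dtype=jnp.uint8)   # 1 marks masked-out positions
#     softmax_module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED_MASKED)
#     probs = softmax_module.apply({}, logits, mask=mask)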


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out
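
# A hedged usage sketch (not part of the library API): initializing and applying RMSNorm.
# The output keeps the input shape and dtype; parameters are allocated in self.dtype.
#
#     x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)
#     ln = LayerNorm(layernorm_type="rmsnorm", epsilon=1e-6)
#     variables = ln.init(jax_random.PRNGKey(0), x)
#     y = ln.apply(variables, x)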


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(
        self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
    ):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        collection_name = (
            variable_collection
            if variable_collection is not None
            else get_quantize_config().COLLECTION_NAME
        )

        if fp8_recipe is None:
            quantize_config = get_quantize_config()
        else:
            quantize_config = get_quantize_config_with_recipe(fp8_recipe)

        x_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.X, "x"
        )
        kernel_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.KERNEL, "kernel"
        )
        grad_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.DGRAD, "grad"
        )

        quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)

        quantizer_set = QuantizerFactory.create_set(
            fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
        )
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y
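
# A hedged usage sketch (axis names and sizes are illustrative, not prescriptive): a
# 512 -> 1024 projection over the last axis with a sharded kernel and bias.
#
#     x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)
#     layer = DenseGeneral(features=1024, kernel_axes=("embed", "mlp"),
#                          use_bias=True, bias_axes=("mlp",))
#     variables = layer.init(jax_random.PRNGKey(0), x)
#     y = layer.apply(variables, x)      # y.shape == (4, 128, 1024)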


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                transpose_batch_sequence=self.transpose_batch_sequence,
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output
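
# A hedged usage sketch (sizes and axis names are illustrative): a fused pre-LayerNorm
# projection that also returns the normalized input for reuse, e.g. by a parallel branch.
#
#     x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)
#     layer = LayerNormDenseGeneral(features=1536, layernorm_type="rmsnorm",
#                                   kernel_axes=("embed", "mlp"),
#                                   return_layernorm_output=True)
#     variables = layer.init(jax_random.PRNGKey(0), x)
#     y, ln_out = layer.apply(variables, x)   # y.shape == (4, 128, 1536)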


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
900
901
902
903
    activation_params: dict, default = None
        The parameters needed(if any) by the activation functions specified in :attr:`activations`.
        At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS
        need additional parameters.
904
905
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
906
    intermediate_dropout_rate: float, default = 0.1
907
        Dropout probability for the dropout op after the :attr:`activations`.
Ming-Xu Huang's avatar
Ming-Xu Huang committed
908
909
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions of the hidden states that will share the same dropout mask.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    activation_params: dict = None
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
            ("clamped_silu", "clamped_linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            # Callable activations are kept as-is; they cannot match the fused pools above and
            # therefore fall back to the generic (non-fused) activation path below.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        normalized_acts = tuple(
            reversed(normalized_acts)
            if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
            else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                activation_params=self.activation_params,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output
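
# A hedged usage sketch (sizes are illustrative): a GELU-gated MLP block with fused layernorm.
# A "dropout" RNG is only needed when dropout is active (deterministic=False).
#
#     x = jnp.ones((4, 128, 512), dtype=jnp.bfloat16)
#     mlp = LayerNormMLP(intermediate_dim=2048,
#                        activations=("gelu", "linear"),
#                        intermediate_dropout_rate=0.1,
#                        return_layernorm_output=False)
#     variables = mlp.init(jax_random.PRNGKey(0), x, deterministic=True)
#     out, ln_out = mlp.apply(variables, x, deterministic=False,
#                             rngs={"dropout": jax_random.PRNGKey(1)})
#     # out.shape == x.shape; ln_out is None because return_layernorm_output=False.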