# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from transformer_engine.common import recipe
from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask
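
# Example (illustrative; hypothetical bias names): _combine_biases(rel_pos_bias,
# alibi_bias) sums the given [batch, heads, q_seqlen, k_seqlen] biases, skipping
# any that are None, and returns None when every argument is None.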


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
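
# Shape sketch for _apply_low_rank_adaptation (illustrative, not executed): with
# one contraction axis of size 1024, features=(8, 128) and rank 32, the einsum
# becomes "...i,ins,nso->...no":
#   x:      (batch, seq, 1024)
#   lora_a: (1024, 8, 32)
#   lora_b: (8, 32, 128)
#   output: (batch, seq, 8, 128), scaled by alpha / 32 when alpha is given.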


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask) * (shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs
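
# Usage sketch for Softmax (illustrative; attention logits are 4D):
#   sm = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED_MASKED)
#   probs = sm.apply({}, logits, mask=mask)  # logits, mask: [batch, heads, sq, skv]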


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out
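
# Usage sketch for LayerNorm (illustrative):
#   ln = LayerNorm(layernorm_type="rmsnorm")
#   params = ln.init(jax_random.PRNGKey(0), x)  # x: [batch, seq, hidden]
#   y = ln.apply(params, x)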


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(
        self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
    ):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            collection_name = (
                variable_collection
                if variable_collection is not None
                else QuantizeConfig.COLLECTION_NAME
            )
            scale = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
            fp8_recipe, recipe.DelayedScaling
        ):
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
        return quantizer_set
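
    # Naming sketch (illustrative): under delayed tensor scaling,
    # generate_quantizer_set("_0") creates Flax variables "x_0_scale",
    # "x_0_amax_history", "kernel_0_scale", ..., "grad_0_amax_history"
    # in the QuantizeConfig.COLLECTION_NAME collection.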


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias : bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation : bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim : int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha : float, default = None
        The alpha for computing the scaling factor of LoRA output:
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis : Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes : Tuple[str, ...], default = ()
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is (), which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes), "
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y
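
# Usage sketch for DenseGeneral (illustrative shapes):
#   layer = DenseGeneral(features=3072, kernel_axes=("embed", "mlp"))
#   params = layer.init(jax_random.PRNGKey(0), x)  # x: [batch, seq, hidden]
#   y = layer.apply(params, x)                     # y: [batch, seq, 3072]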


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm : bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias : bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output : bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation : bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim : int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha : float, default = None
        The alpha for computing the scaling factor of LoRA output:
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis : Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes : Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes : Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling : float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs : jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output
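
# Usage sketch for LayerNormDenseGeneral (illustrative): the module returns a pair.
#   layer = LayerNormDenseGeneral(features=1024, return_layernorm_output=True)
#   params = layer.init(jax_random.PRNGKey(0), x)
#   z, ln_out = layer.apply(params, x)  # ln_out is None when return_layernorm_output=False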


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim : int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm : bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias : bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1 : Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2 : Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output : bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations : Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name : str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used
        for generating Dropout masks.
    intermediate_dropout_rate : float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims : Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden.
    enable_low_rank_adaptation : bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim : int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha : float, default = None
        The alpha for computing the scaling factor of LoRA output:
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis : Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes : Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes : Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes : Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name : str, default = 'ffn1'
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name : str, default = 'ffn2'
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            # Callable activations are handled by the unfused fallback path below;
            # only strings are normalized for the fused-kernel lookup.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
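
        # Normalization sketch (illustrative): activations=("linear", "gelu")
        # becomes ("gelu", "linear"), which matches a gated_act_pool entry and
        # so can take the fused path; activations=("tanh",) is not in the pools
        # and falls back to the generic implementation below.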

        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.enable_low_rank_adaptation:
                # Unpack the leading kernel dims so this stays a flat shape tuple.
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                # One extra leading axis is stacked in for the activation dimension.
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output
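
# Usage sketch for LayerNormMLP (illustrative): a gated-GELU MLP block.
#   mlp = LayerNormMLP(intermediate_dim=4096, activations=("gelu", "linear"))
#   params = mlp.init(jax_random.PRNGKey(0), x, deterministic=True)
#   out, ln_out = mlp.apply(params, x, deterministic=False,
#                           rngs={"dropout": jax_random.PRNGKey(1)})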