module.py 49.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
7
from functools import reduce
8
import operator
Alp Dener's avatar
Alp Dener committed
9
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
10
11

import numpy as np
12
import jax.numpy as jnp
13
14
15
from flax import linen as nn
from jax import lax
from jax import random as jax_random
16
from jax.ad_checkpoint import checkpoint_name
17

18
19
from transformer_engine.common import recipe

20
from ..dense import dense
21
22
23

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
24
25
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
26
from ..activation import activation
27
from ..softmax import softmax, SoftmaxType
28
from ..sharding import with_sharding_constraint_by_logical_axes
29
30
31
32
33
34
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
35
36
37
38
39
40
41
42
from ..quantize import (
    QuantizerFactory,
    get_quantize_config,
    QuantizeMeta,
    QuantizeMetaSet,
    ScalingMode,
    TensorSource,
)
43
44
45

# Type aliases used throughout this module's annotations.
PRNGKey = Any  # JAX PRNG key; treated as opaque here
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)

# Accepted forms for matmul precision arguments, mirroring jax.lax dot APIs.
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]

# Parameter initializer contract: (rng_key, shape, dtype) -> Array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


65
def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
66
67
68
69
70
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
71
72
73
    return nn.initializers.zeros


74
def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    """Register the learnable layernorm parameters on *module*.

    Always creates the "scale" (gamma) parameter; the "ln_bias" (beta)
    parameter is created only for 'layernorm' — 'rmsnorm' has no shift term.
    Parameters are allocated in *dtype* and cast to *input_dtype* for compute.
    Returns ``(scale, bias)`` where ``bias`` is ``None`` for rmsnorm.
    """
    gamma = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    canonical = canonicalize_norm_type(norm_type)
    if canonical == "rmsnorm":
        return gamma, None

    assert canonical == "layernorm"
    beta = module.param(
        "ln_bias",
        nn.with_logical_partitioning(bias_init, bias_axes),
        shape,
        dtype,
    ).astype(input_dtype)
    return gamma, beta


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
109
    if fn_or_string == "linear":
110
111
112
113
114
115
116
117
118
119
120
121
122
123
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
124
125
126
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
127
128
129
130
131
132
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


133
134
135
136
def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
137
    hidden_in_names = "ijklm"[: len(axis)]
138
    assert len(features) <= 5
139
140
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"
141
142
143
144
145
146
147
148
149

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
150
151
152
153
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )
154
155
156
157
158
159

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


160
class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python
        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        input_dtype = inputs.dtype
        batch = inputs.shape[0]
        num_heads = inputs.shape[1]
        q_len = inputs.shape[2]
        k_len = inputs.shape[3]

        # Any additive bias is folded into the logits before either path runs.
        logits = inputs if bias is None else inputs + bias.astype(input_dtype)

        use_fused_kernel = is_softmax_kernel_available(
            self.softmax_type, batch, num_heads, q_len, k_len, input_dtype
        )

        if use_fused_kernel:
            # The fused primitive only consumes the mask for SCALED_MASKED.
            kernel_mask = mask if self.softmax_type is SoftmaxType.SCALED_MASKED else None
            outputs = softmax(logits, kernel_mask, self.scale_factor, self.softmax_type)
        # Fall back to the plain JAX implementations.
        elif self.softmax_type is SoftmaxType.SCALED:
            outputs = jax_scaled_softmax(logits, self.scale_factor)
        elif self.softmax_type is SoftmaxType.SCALED_MASKED:
            outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
        elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
            outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
        else:
            raise ValueError(
                f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
            )

        assert input_dtype == outputs.dtype
        return outputs


223
class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        # When no scale initializer is given, pick the default that matches
        # the gamma formulation (zeros for zero-centered, ones otherwise).
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype
        hidden_size = x.shape[-1]

        # gamma is always created; beta only exists for 'layernorm'.
        gamma, beta = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (hidden_size,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )

        normed = layernorm(
            x,
            gamma,
            beta,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert normed.dtype == input_dtype
        return normed
342
343
344


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(
        self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
    ):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def _make_quantize_meta(name: str):
            # Resolve the flax variable collection lazily so the quantize
            # config is only consulted when a meta is actually created.
            collection = (
                variable_collection
                if variable_collection is not None
                else get_quantize_config().COLLECTION_NAME
            )
            scale_var = self.variable(
                collection,
                f"{name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            )
            amax_var = self.variable(
                collection,
                f"{name}{postfix}_amax_history",
                jnp.zeros,
                (get_quantize_config().AMAX_HISTORY_LEN,),
                jnp.float32,
            )
            return QuantizeMeta(scale=scale_var.value, amax_history=amax_var.value)

        # Delayed tensor scaling needs persistent scale/amax state; other
        # scaling modes derive quantization parameters on the fly.
        needs_delayed_state = get_quantize_config().get_scaling_mode(
            TensorSource.X
        ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling)

        factory_kwargs = {}
        if needs_delayed_state:
            factory_kwargs["quantize_meta_set"] = QuantizeMetaSet(
                x=_make_quantize_meta("x"),
                kernel=_make_quantize_meta("kernel"),
                grad=_make_quantize_meta("grad"),
            )

        return QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **factory_kwargs)
391
392
393


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = ()
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is (), which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        # Default kernel initializer matches the documented variance-scaling scheme.
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        # Kernel shape: the contracted input dims followed by the output dims.
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f" got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        # With FP8 enabled the quantized dense path handles dtype conversion;
        # otherwise compute in the input dtype.
        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            # LoRA A: contracted input dims x leading output dims x rank.
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            # LoRA B starts at zero so the adapter is a no-op at initialization.
            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            # Broadcast the bias across all leading (non-feature) dims.
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. The output is divided
        by this factor. It should be a float
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        # Default scale initializer depends on the gamma formulation.
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis = -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        # The fused layernorm+dense kernel can only be used when FP8 is on and
        # the (un-fused) layernorm output does not need to be returned.
        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                # Normalization happens inside layernorm_dense below.
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            # NOTE(review): in the fused path `y` is the *pre*-layernorm input,
            # so LoRA is applied to un-normalized activations there, unlike the
            # non-fused path — confirm this asymmetry is intended.
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            # LoRA B starts at zero so the adapter is a no-op at initialization.
            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            # Broadcast the bias across all leading (non-feature) dims.
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        return z, ln_output  # dense_output, layer_norm_output
829
830
831
832
833


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"

    def __post_init__(self):
        """Resolve dtype-dependent initializer defaults before module setup."""
        # The documented default kernel initializer depends on self.dtype, so it
        # cannot be a plain dataclass field default; create it lazily here.
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        # The scale_init default depends on zero_centered_gamma
        # (zeros when True, ones otherwise) — see the class docstring.
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        # Activation combinations with fused kernels available.
        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            # Strings are case-normalized; callables are kept as-is and handled
            # by the generic (non-fused) activation path below via
            # _convert_to_activation_function. (Previously this returned False
            # for a callable, violating the declared return contract.)
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        # Canonicalize gated activations to ('act', 'linear') order.
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        # The fused LayerNorm-MLP kernel has no dropout support, so it is only
        # usable when the dropout rate is effectively zero.
        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )

        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                # LayerNorm is applied inside the fused op later on.
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            # One independently initialized kernel per activation branch,
            # stacked along `stack_axis`.
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.enable_low_rank_adaptation:
                # NOTE(review): was `(kernel_1_each_shape[: len(axis)], ...)`,
                # which nested a tuple inside the shape tuple; unpacked here.
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                # +1 accounts for the stacked activation axis that kernel_1_init
                # adds. (Was `len(wi_lora_a_kernel_each_shape + 1)`: TypeError.)
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                # B matrix starts at zero so LoRA is a no-op at initialization.
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                # Generic path: apply each activation to its own branch along
                # axis -2 and combine branches multiplicatively (gated acts).
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                # B matrix starts at zero so LoRA is a no-op at initialization.
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # dense_output, layer_norm_output