"vllm/vscode:/vscode.git/clone" did not exist on "cb08cd0d754f5ac352e31bf362062ba61403eb02"
module.py 49.5 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
7
from functools import reduce
8
import operator
Alp Dener's avatar
Alp Dener committed
9
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
10
11

import numpy as np
12
import jax.numpy as jnp
13
14
15
from flax import linen as nn
from jax import lax
from jax import random as jax_random
16
from jax.ad_checkpoint import checkpoint_name
17

18
from ..dense import dense
19
20
21

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
22
23
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
24
from ..activation import activation
25
from ..softmax import softmax, SoftmaxType
26
from ..sharding import with_sharding_constraint_by_logical_axes
27
28
29
30
31
32
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
33
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
34
35
36

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)

PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"

    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python
        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
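
    Example
    -------
    A minimal usage sketch (the shapes, dtype, and scale factor below are illustrative):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        softmax_module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
        logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
        variables = softmax_module.init(jax.random.PRNGKey(0), logits)
        probs = softmax_module.apply(variables, logits)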
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is a learnable affine transform parameter of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
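
    Example
    -------
    A minimal usage sketch (the input shape and dtype below are illustrative):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNorm(layernorm_type="rmsnorm")
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)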
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(
        self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
    ):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            collection_name = (
                variable_collection
                if variable_collection is not None
                else QuantizeConfig.COLLECTION_NAME
            )
            scale = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                collection_name,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
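
    Example
    -------
    A minimal usage sketch (the feature size and input shape below are illustrative,
    and FP8 is assumed to be disabled):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = DenseGeneral(features=3072, use_bias=True)
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # shape (8, 128, 3072)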
424
425
426
427
428
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in the outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
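
    Example
    -------
    A minimal usage sketch (sizes and shapes below are illustrative, and FP8 is
    assumed to be disabled):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormDenseGeneral(features=3072, return_layernorm_output=False)
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)  # ln_out is None here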
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in the outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.


    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
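
    Example
    -------
    A minimal usage sketch (sizes, shapes, and the activation pair below are
    illustrative, and FP8 is assumed to be disabled):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormMLP(
            intermediate_dim=4096,
            activations=("gelu", "linear"),
            return_layernorm_output=False,
        )
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = layer.apply(variables, x, deterministic=True)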
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
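        # A pair such as ("gelu", "linear") denotes a gated activation (e.g. GEGLU):
        # the chunk passed through "linear" multiplies the activated chunk, which is
        # exactly what the unfused fallback below computes via reduce(operator.mul, ...).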
        normalized_acts = []
        for act in self.activations:
            # Keep callables as-is; they are handled by the unfused fallback path
            # below via _convert_to_activation_function.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
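        # The fully fused LayerNorm + MLP path is only taken when FP8 fusion is active,
        # the activation combination has a fused implementation, and the intermediate
        # dropout is effectively disabled.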
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
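        # wi_kernel stacks one (hidden, intermediate) kernel per activation along
        # axis -2, giving a parameter of shape (hidden, num_activations, intermediate).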
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output