"vscode:/vscode.git/clone" did not exist on "2729087efe1f3e75383eceee09273168da1b8809"
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return the default initializer for the layernorm scale parameter when none is given."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
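    # Illustrative (assumed) example: with a single contracted axis and a single
    # output feature axis, the strings above become "...i", "is", and "sn", so the
    # einsum below evaluates "...i,is,sn->...n", i.e. x @ lora_a @ lora_b.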
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
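
    Example
    -------
    A minimal usage sketch (the shape of ``logits`` and the scale factor are
    illustrative, not prescribed by this module):

    .. code-block:: python

        module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
        variables = module.init(jax.random.PRNGKey(0), logits)
        probs = module.apply(variables, logits)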
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
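
    Example
    -------
    A minimal usage sketch (``x`` is an illustrative activation tensor whose
    last dimension is the hidden size):

    .. code-block:: python

        module = LayerNorm(layernorm_type='rmsnorm', epsilon=1e-6)
        variables = module.init(jax.random.PRNGKey(0), x)
        y = module.apply(variables, x)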
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

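        # With delayed tensor scaling, per-tensor scale and amax-history variables
        # are stored in the module's quantization collection and handed to the
        # quantizer factory; other scaling modes need no extra state here.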
        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
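
    Example
    -------
    A minimal usage sketch (``x`` is an illustrative input whose last dimension
    is contracted against the kernel):

    .. code-block:: python

        module = DenseGeneral(features=1024, use_bias=True)
        variables = module.init(jax.random.PRNGKey(0), x)
        y = module.apply(variables, x)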
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes), "
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
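
    Example
    -------
    A minimal usage sketch (``x`` is an illustrative input; with
    ``return_layernorm_output=False`` the second output is None):

    .. code-block:: python

        module = LayerNormDenseGeneral(features=3072, layernorm_type='rmsnorm',
                                       return_layernorm_output=False)
        variables = module.init(jax.random.PRNGKey(0), x)
        y, ln_out = module.apply(variables, x)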
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden states.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.


    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
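
    Example
    -------
    A minimal usage sketch (``x`` is an illustrative input; the RNG name matches
    ``intermediate_dropout_rng_name``):

    .. code-block:: python

        module = LayerNormMLP(intermediate_dim=4096, activations=('gelu', 'linear'))
        variables = module.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = module.apply(variables, x, deterministic=False,
                                 rngs={'dropout': jax.random.PRNGKey(1)})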
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
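        # Activation tuples that pair an activation with "linear" select a gated
        # variant (e.g. ("gelu", "linear") for gated GELU); single-element tuples
        # apply the activation directly.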
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                return False
            normalized_acts.append(act.lower())
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

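        # wi_kernel stacks one independently initialized kernel per activation
        # along axis -2, so gated activations (e.g. ("gelu", "linear")) each get
        # their own projection within a single parameter.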
        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output