# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
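    """Return ``original_init`` if provided; otherwise return the default scale initializer implied by ``zero_centered_gamma``."""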
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
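    """Create the layernorm ``scale`` (and, for ``norm_type='layernorm'``, ``ln_bias``) parameters, allocated in ``dtype`` and cast to ``input_dtype``."""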
    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
    scale = scale.astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
        bias = jnp.asarray(bias, input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python
        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
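
    Example
    -------
    A minimal usage sketch (shapes and the scale factor are illustrative):

    .. code-block:: python

        import jax.numpy as jnp

        logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
        mask = jnp.zeros((2, 1, 128, 128), dtype=jnp.uint8)  # 1 marks masked positions
        softmax_module = Softmax(scale_factor=0.125)
        probs = softmax_module.apply({}, logits, mask=mask)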
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):

            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(
                    mask > 0,
                    jnp.full(mask.shape, -1e10),
                    jnp.full(mask.shape, 0.0),
                )
                attention_bias = attention_bias.astype(input_dtype)

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(input_dtype)

            # If self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED but its fused
            # kernel is unavailable, fall back to the pure scaled softmax custom call.
            if is_softmax_kernel_available(
                SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype
            ):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
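
    Example
    -------
    A minimal usage sketch (shapes are illustrative):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((8, 128, 512), dtype=jnp.bfloat16)  # (batch, seqlen, hidden)
        layer = LayerNorm(layernorm_type="rmsnorm")
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # same shape and dtype as x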
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class for Transformer Engine Flax modules; provides helpers for building
    the quantizer sets used by the FP8 GEMMs.
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set
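
    # Usage sketch (illustrative): subclasses call this once per GEMM, e.g.
    #   quantizer_set = self.generate_quantizer_set("_0")
    # and pass the result to dense() / layernorm_dense() / layernorm_mlp(). Under
    # delayed tensor scaling this also allocates per-GEMM scale and amax_history
    # variables in the QuantizeConfig.COLLECTION_NAME collection.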


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = ()
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is (), which means no
        sharding constraint is inserted.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
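
    Example
    -------
    A minimal usage sketch (shapes are illustrative; FP8 behavior follows the
    global QuantizeConfig state):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((8, 128, 512), dtype=jnp.bfloat16)  # (batch, seqlen, hidden)
        layer = DenseGeneral(features=1024)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # (8, 128, 1024)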
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = lora_a_kernel.astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. If None, no scaling is applied.
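
    Example
    -------
    A minimal usage sketch (shapes are illustrative; note the default
    transpose_batch_sequence=True layout):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((128, 8, 512), dtype=jnp.bfloat16)  # (seqlen, batch, hidden)
        layer = LayerNormDenseGeneral(features=1024, return_layernorm_output=False)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)  # ln_out is None here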
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_shape,
                self.dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = lora_a_kernel.astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('mlp',)
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the RNGs dict given to flax.linen.Module.apply that is used to generate Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions of the intermediate (hidden) tensor that share the same dropout mask.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
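
    Example
    -------
    A minimal usage sketch (shapes and the activation pair are illustrative):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((128, 8, 512), dtype=jnp.bfloat16)  # (seqlen, batch, hidden)
        mlp = LayerNormMLP(
            intermediate_dim=2048,
            activations=("gelu", "linear"),
            return_layernorm_output=False,
        )
        variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = mlp.apply(variables, x, deterministic=True)  # y has x's shape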
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )
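        # Activation sets with fused TE kernels: a pair ending in "linear" denotes a
        # gated variant (e.g. ("gelu", "linear") is GeGLU). Anything not listed in
        # gated_act_pool / act_pool falls back to the generic per-activation path below.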

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            # A non-string (callable) activation cannot match the fused kernels below;
            # keep it as-is so the generic per-activation path can still handle it.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
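            # Initialize one kernel per activation function and stack them along a new
            # axis (stack_axis = -2 below), giving wi_kernel the shape
            # (hidden_in, num_activations, intermediate_dim).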
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes(
            "wi_kernel",
            kernel_1_init,
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
            axes=self.kernel_axes_1,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = nn_partitioning.param_with_axes(
            "wo_kernel",
            self.kernel_init,
            kernel_2_shape,
            self.dtype,
            axes=self.kernel_axes_2,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = nn_partitioning.param_with_axes(
                "wi_bias",
                self.bias_init,
                bias_1_shape,
                self.dtype,
                axes=self.bias_axes_1,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = nn_partitioning.param_with_axes(
                "wo_bias",
                self.bias_init,
                bias_2_shape,
                self.dtype,
                axes=self.bias_axes_2,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            dot_1_output_axes = (
                *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
                *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
            )
            x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_a_kernel",
                    kernel_1_init,
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                    axes=wi_lora_a_kernel_axes,
                )
                wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_b_kernel",
                    nn.initializers.zeros,
                    wi_lora_b_kernel_shape,
                    self.dtype,
                    axes=wi_lora_b_kernel_axes,
                )
                wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_a_kernel",
                    self.kernel_init,
                    wo_lora_a_kernel_shape,
                    self.dtype,
                    axes=wo_lora_a_kernel_axes,
                )
                wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_b_kernel",
                    nn.initializers.zeros,
                    wo_lora_b_kernel_shape,
                    self.dtype,
                    axes=wo_lora_b_kernel_axes,
                )
                wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output