# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense, _issue_batch_first_warning as _dense_warning

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)

PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]

Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
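# A minimal shape sketch for the helper above (illustrative only; the shapes below are
# assumptions, not part of the module):
#
#   x      : (batch, seq, hidden_in)      -> "...i"
#   lora_a : (hidden_in, rank)            -> "is"
#   lora_b : (rank, hidden_out)           -> "so"
#   output : (batch, seq, hidden_out)     -> "...o"
#
# i.e. output = jnp.einsum("...i,is,so->...o", x, lora_a, lora_b) * (alpha / rank)
# for a single contracting axis and a single output feature axis.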


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
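
    Example (a minimal usage sketch; shapes and the scale value are illustrative
    assumptions, not requirements of this module):

    .. code-block:: python

        attn_softmax = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
        logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
        variables = attn_softmax.init(jax.random.PRNGKey(0), logits)
        probs = attn_softmax.apply(variables, logits)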
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
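
    Example (a minimal usage sketch; the shapes are illustrative assumptions):

    .. code-block:: python

        ln = LayerNorm(layernorm_type="rmsnorm", epsilon=1e-6)
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)  # (batch, seqlen, hidden)
        variables = ln.init(jax.random.PRNGKey(0), x)
        y = ln.apply(variables, x)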
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set
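
    # Note (illustrative): with ScalingMode.DELAYED_TENSOR_SCALING, a subclass calling
    # `self.generate_quantizer_set("_0")` registers per-GEMM meta variables named
    # "x_0_scale", "x_0_amax_history", "kernel_0_scale", "kernel_0_amax_history",
    # "grad_0_scale" and "grad_0_amax_history" in the QuantizeConfig.COLLECTION_NAME
    # collection; other scaling modes create the quantizer set without these variables.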


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
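
    Example (a minimal usage sketch; shapes and logical axis names are illustrative
    assumptions):

    .. code-block:: python

        layer = DenseGeneral(features=4096, kernel_axes=("embed", "mlp"),
                             use_bias=True, bias_axes=("mlp",))
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)  # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # y.shape == (8, 128, 4096)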
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _dense_warning(
                "TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
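
    Example (a minimal usage sketch; shapes and option values are illustrative
    assumptions):

    .. code-block:: python

        layer = LayerNormDenseGeneral(
            features=3072,
            layernorm_type="rmsnorm",
            return_layernorm_output=False,
            transpose_batch_sequence=False,
            kernel_axes=("embed", "mlp"),
        )
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)  # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)  # ln_out is None here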
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _ln_dense_warning(
                "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
                "inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
                "Use sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )
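        # LayerNorm is folded into the FP8 GEMM only when FP8 is enabled, layer
        # normalization is requested, and its standalone output is not needed.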

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
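
    Example (a minimal usage sketch; shapes and option values are illustrative
    assumptions):

    .. code-block:: python

        mlp = LayerNormMLP(
            intermediate_dim=4096,
            activations=("gelu", "linear"),
            return_layernorm_output=False,
            transpose_batch_sequence=False,
        )
        x = jnp.ones((8, 128, 1024), dtype=jnp.bfloat16)  # (batch, seqlen, hidden)
        variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = mlp.apply(variables, x, deterministic=True)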
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _ln_mlp_warning(
                "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                return False
            normalized_acts.append(act.lower())
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )
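        # e.g. ("linear", "gelu") is normalized to ("gelu", "linear") so gated activations
        # are matched against gated_act_pool regardless of the order they were given in.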

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
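            # Each activation branch gets its own independently initialized kernel; the
            # per-activation kernels are then stacked along `stack_axis` (-2 at the call
            # site below), giving wi_kernel shape (hidden_in, num_activations, intermediate_dim).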
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
                dot_1_output_axes = (
                    *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
                    *get_non_contracting_logical_axes(
                        kernel_1.ndim, self.kernel_axes_1, contract_ind
                    ),
                )
                x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output