# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray

PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]

Initializer = Callable[[PRNGKey, Shape, DType], Array]
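# An illustrative callable compatible with the ``Initializer`` alias, e.g.:
#   init = lambda key, shape, dtype: jax_random.normal(key, shape, dtype) * 0.02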


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)



def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return original_init if given, otherwise the default implied by zero_centered_gamma."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    """Create the 'scale' parameter (and 'ln_bias' for LayerNorm) for a normalization layer."""
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Apply a low-rank adaptation (LoRA) update: ``einsum(x, A, B) * (alpha / rank)``.

    For example, with ``len(axis) == 1`` and ``features == (out_dim,)`` the
    contraction is ``"...i,is,sn->...n"``, where ``s`` is the rank dimension.
    """

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask) * (shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
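
    A minimal usage sketch (shapes and the head size 64 are illustrative; this
    module holds no parameters, so an empty variable collection is passed to
    ``apply``):

    .. code-block:: python

        import math
        import jax.numpy as jnp

        logits = jnp.zeros((4, 16, 128, 128))  # (batch, heads, q_seqlen, k_seqlen)
        probs = Softmax(scale_factor=1.0 / math.sqrt(64)).apply({}, logits)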
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be one of"
                    " [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    the size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{\mathrm{RMS}[x]} * \gamma

    .. math::
        \mathrm{RMS}[x] = \sqrt{\mathrm{E}[x^2] + \epsilon}

    :math:`\gamma` is a learnable affine transform parameter of
    the size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
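
    A minimal usage sketch (shapes are illustrative):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((4, 128, 512))  # (batch, seqlen, hidden)
        layer = LayerNorm(layernorm_type="rmsnorm")
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)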
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of Transformer Engine modules.
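
    A minimal FP8 usage sketch, assuming ``transformer_engine.jax.fp8_autocast``
    as the autocast context manager (the exact import path may differ across
    releases) and the ``DenseGeneral`` module defined below:

    .. code-block:: python

        import jax
        import jax.numpy as jnp
        import transformer_engine.jax as te

        x = jnp.ones((4, 128, 512))
        layer = DenseGeneral(features=1024)
        with te.fp8_autocast(enabled=True):
            variables = layer.init(jax.random.PRNGKey(0), x)
            y = layer.apply(variables, x)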
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of quantizers (with their FP8 metadata) for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
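
    A minimal usage sketch (sizes are illustrative):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((4, 128, 512))  # (batch, seqlen, hidden)
        layer = DenseGeneral(features=1024)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # y: (4, 128, 1024)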
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    input_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes), "
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
484
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
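
    A minimal usage sketch (sizes are illustrative; note the default
    ``transpose_batch_sequence=True`` expects (seqlen, batch, hidden) inputs):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((128, 4, 512))  # (seqlen, batch, hidden)
        layer = LayerNormDenseGeneral(features=1024, return_layernorm_output=False)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)  # ln_out is None here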
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the RNGs dict passed to flax.linen.Module.apply that is used to
        generate Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the intermediate (hidden) dropout.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
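
    A minimal usage sketch (sizes and activations are illustrative; dropout is
    disabled via ``deterministic=True``, otherwise a 'dropout' RNG must be
    provided to ``init``/``apply``):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((128, 4, 512))  # (seqlen, batch, hidden)
        layer = LayerNormMLP(intermediate_dim=2048, activations=("gelu", "linear"))
        variables = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = layer.apply(variables, x, deterministic=True)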
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        all_str_acts = True
        for act in self.activations:
            if not isinstance(act, str):
                # Callable activations cannot be matched against the fused-activation
                # pools above; fall back to the unfused implementation below.
                all_str_acts = False
                break
            normalized_acts.append(act.lower())
        if all_str_acts:
            normalized_acts = tuple(
                reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
            )
        else:
            normalized_acts = tuple(self.activations)

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
                dot_1_output_axes = (
                    *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
                    *get_non_contracting_logical_axes(
                        kernel_1.ndim, self.kernel_axes_1, contract_ind
                    ),
                )
                x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output