# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional

import numpy as np
import jax.numpy as jnp

from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import (
    QuantizerFactory,
    get_quantize_config,
    QuantizeMetaSet,
    TensorSource,
    get_quantize_config_with_recipe,
)

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
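    # For example (illustrative): with one contracting axis and one output
    # feature axis, the expression assembled below is "...i,is,sn->...n", i.e.
    # lora_a maps the contracted axes into the rank ("s") space and lora_b maps
    # the rank space to the final output features.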
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python
        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
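
    A minimal usage sketch (`batch`, `heads`, `q_seqlen` and `k_seqlen` are
    illustrative placeholders; the module has no parameters, so an empty
    variable collection is passed to `apply`):

    .. code-block:: python

        attn_scores = jnp.zeros((batch, heads, q_seqlen, k_seqlen), jnp.bfloat16)
        probs = Softmax(scale_factor=0.125).apply({}, attn_scores)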
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
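
    A minimal usage sketch (`batch`, `seqlen` and `hidden` are illustrative
    placeholders):

    .. code-block:: python

        x = jnp.ones((batch, seqlen, hidden), jnp.float32)
        layer = LayerNorm(layernorm_type='rmsnorm')
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)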
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(
        self,
        postfix: str = "",
        variable_collection: str = None,
        quantization_checkpoint_name: Optional[str] = None,
        fp8_recipe=None,
    ):
        """
        Generate a set of FP8 meta for a GEMM.
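
        A minimal usage sketch inside a subclass (the postfix distinguishes the
        metas of multiple GEMMs within one module; "_0" here is illustrative):

        .. code-block:: python

            quantizer_set = self.generate_quantizer_set(postfix="_0")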
        """

        collection_name = (
            variable_collection
            if variable_collection is not None
            else get_quantize_config().COLLECTION_NAME
        )

        if fp8_recipe is None:
            quantize_config = get_quantize_config()
        else:
            quantize_config = get_quantize_config_with_recipe(fp8_recipe)

        x_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.X, "x"
        )
        kernel_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.KERNEL, "kernel"
        )
        grad_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.DGRAD, "grad"
        )

        quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)

        quantizer_set = QuantizerFactory.create_set(
            fp8_recipe=fp8_recipe,
            quantize_meta_set=quantize_meta_set,
            checkpoint_name=quantization_checkpoint_name,
        )
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name: Optional[str], default = None
        The name for checkpointing quantizations.
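
    A minimal usage sketch (`batch`, `seqlen` and `hidden` are illustrative
    placeholders):

    .. code-block:: python

        layer = DenseGeneral(features=4096)
        x = jnp.ones((batch, seqlen, hidden), jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)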
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set(
            quantization_checkpoint_name=self.quantization_checkpoint_name
        )
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        if self.enable_low_rank_adaptation:
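            # LoRA adds two trainable kernels: lora_a projects the contracted
            # input axes down to `low_rank_adaptation_dim`, and lora_b projects
            # back up to the output features. lora_b is zero-initialized, so the
            # adapted layer initially matches the base layer.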
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = False
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name: Optional[str], default = None
        The name for checkpointing quantizations.
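
    A minimal usage sketch (sizes are illustrative); the module returns a pair
    of (dense output, layernorm output):

    .. code-block:: python

        layer = LayerNormDenseGeneral(features=3 * hidden, return_layernorm_output=True)
        x = jnp.ones((batch, seqlen, hidden), jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)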
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = False
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set(
            quantization_checkpoint_name=self.quantization_checkpoint_name
        )

        # Layernorm can be fused into the FP8 GEMM only when the normalized
        # tensor is not needed as a separate output.
        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                transpose_batch_sequence=self.transpose_batch_sequence,
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = False
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('gelu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    activation_params: dict, default = None
        The parameters needed (if any) by the activation functions specified in
        :attr:`activations`. At the moment only ('clamped_silu', 'clamped_linear'),
        the clamped SwiGLU used in GPT-OSS, needs additional parameters.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the RNGs given via flax.linen.Module.apply that is used to
        generate Dropout masks.
    intermediate_dropout_rate: float, default = 0.0
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions of the hidden states that will share the same dropout mask.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str, default = 'ffn1'
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str, default = 'ffn2'
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name: Optional[str], default = None
        The name for checkpointing quantizations.
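
    A minimal usage sketch (sizes are illustrative; the dropout RNG name follows
    the default 'dropout'):

    .. code-block:: python

        mlp = LayerNormMLP(intermediate_dim=4 * hidden, activations=('gelu', 'linear'))
        x = jnp.ones((batch, seqlen, hidden), jnp.bfloat16)
        variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = mlp.apply(
            variables, x, deterministic=False, rngs={'dropout': jax.random.PRNGKey(1)}
        )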
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = False
    activations: Sequence[Union[str, Callable]] = ("gelu",)
    activation_params: dict = None
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.0
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set(
            "_0", quantization_checkpoint_name=self.quantization_checkpoint_name
        )
        ffn2_quantizer_set = self.generate_quantizer_set(
            "_1", quantization_checkpoint_name=self.quantization_checkpoint_name
        )

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            get_quantize_config().is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
            ("clamped_silu", "clamped_linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            # Lower-case string names; pass callables through so they fall back
            # to the generic (non-fused) activation path below.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        # Canonicalize gated activations so the "linear" (gating) branch comes second.
        normalized_acts = tuple(
            reversed(normalized_acts)
            if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
            else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # Initialize one kernel per activation branch and stack them along the
        # second-to-last axis, giving wi_kernel the shape
        # (hidden_in, num_activations, intermediate_dim).
        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not get_quantize_config().is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not get_quantize_config().is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                activation_params=self.activation_params,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                # One extra axis comes from kernel_1_init's per-activation stacking.
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output