# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer-related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import (
    QuantizerFactory,
    get_global_quantize_recipe,
    QuantizeMetaSet,
    TensorSource,
    get_quantize_config_with_recipe,
    noop_quantizer_set,
)

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return `original_init` if given, otherwise the default scale init implied by
    `zero_centered_gamma` (zeros when True, ones when False)."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)
    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


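# Illustrative sketch (editor's addition, not part of the library): combining two
# attention biases of the same rank with _combine_biases; shapes are assumptions.
def _combine_biases_sketch():
    rel_pos_bias = jnp.zeros((1, 8, 128, 128))    # broadcasts over batch
    attn_mask_bias = jnp.zeros((2, 1, 128, 128))  # broadcasts over heads
    # None entries are dropped; the remaining biases are summed with broadcasting.
    return _combine_biases(rel_pos_bias, None, attn_mask_bias)

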
def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


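# Illustrative sketch (editor's addition, not part of the library): a minimal
# LoRA update for a 2D input contracted over its last axis. Shapes, rank, and
# alpha below are assumptions chosen for demonstration only.
def _lora_adaptation_sketch():
    x = jnp.ones((4, 16))             # (batch, hidden_in)
    lora_a = jnp.full((16, 8), 0.01)  # (hidden_in, rank)
    lora_b = jnp.zeros((8, 32))       # (rank, hidden_out); zeros make LoRA a no-op at init
    # The update is scaled by alpha / rank = 16 / 8 = 2.0.
    return _apply_low_rank_adaptation(x, (1,), (32,), lora_a, lora_b, alpha=16.0)

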
class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.

    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask) * (shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be one of"
                    " [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


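# Illustrative usage sketch (editor's addition): applying the Softmax module to
# attention logits. The shapes and scale factor are assumptions for demonstration.
def _softmax_usage_sketch():
    logits = jnp.zeros((2, 8, 128, 128), dtype=jnp.float32)  # [batch, heads, q_seqlen, k_seqlen]
    module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
    # The module holds no learnable parameters, so an empty variable dict suffices.
    return module.apply({}, logits)

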
class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh,
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """
    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out

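# Illustrative usage sketch (editor's addition): initializing and applying the
# LayerNorm module. The input shape and RNG seed are assumptions for demonstration.
def _layernorm_usage_sketch():
    x = jnp.ones((2, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
    module = LayerNorm(layernorm_type="rmsnorm", epsilon=1e-6)
    variables = module.init(jax_random.PRNGKey(0), x)
    return module.apply(variables, x)
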

class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of Transformer Engine modules.
    """

    def generate_quantizer_set(
        self,
        postfix: str = "",
        variable_collection: Optional[str] = None,
        quantization_checkpoint_name: Optional[str] = None,
        fp8_recipe=None,
    ):
        """
        Generate a set of FP8 metadata and quantizers for a GEMM.
        """

        if fp8_recipe is None:
            fp8_recipe = get_global_quantize_recipe()

        quantize_config = get_quantize_config_with_recipe(fp8_recipe)

        collection_name = (
            variable_collection
            if variable_collection is not None
            else quantize_config.COLLECTION_NAME
        )

        x_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.X, "x"
        )
        kernel_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.KERNEL, "kernel"
        )
        grad_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.DGRAD, "grad"
        )
        quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)

        quantizer_set = QuantizerFactory.create_set(
            fp8_recipe=fp8_recipe,
            quantize_meta_set=quantize_meta_set,
            checkpoint_name=quantization_checkpoint_name,
        )
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name: Optional[str], default = None
        The name for checkpointing quantizations.
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes), "
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        quantizer_set = self.generate_quantizer_set(
            quantization_checkpoint_name=self.quantization_checkpoint_name
        )

        if quantizer_set == noop_quantizer_set:
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


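# Illustrative usage sketch (editor's addition): a DenseGeneral projection with
# LoRA enabled. Shapes, dimensions, and the seed are assumptions for demonstration.
def _dense_general_usage_sketch():
    x = jnp.ones((2, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
    module = DenseGeneral(
        features=1024,
        enable_low_rank_adaptation=True,
        low_rank_adaptation_dim=16,
        low_rank_adaptation_alpha=32.0,
    )
    variables = module.init(jax_random.PRNGKey(0), x)
    return module.apply(variables, x)               # -> (2, 128, 1024)

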
class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = False
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name: Optional[str], default = None
        The name for checkpointing quantizations.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = False
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set(
            quantization_checkpoint_name=self.quantization_checkpoint_name
        )

        fuse_layernorm = (
            quantizer_set != noop_quantizer_set
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if quantizer_set == noop_quantizer_set:
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                transpose_batch_sequence=self.transpose_batch_sequence,
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


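# Illustrative usage sketch (editor's addition): a fused LayerNorm + projection
# that also returns the layernorm output. Shapes and the seed are assumptions.
def _layernorm_dense_usage_sketch():
    x = jnp.ones((2, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
    module = LayerNormDenseGeneral(
        features=1536,
        layernorm_type="rmsnorm",
        return_layernorm_output=True,
    )
    variables = module.init(jax_random.PRNGKey(0), x)
    out, ln_out = module.apply(variables, x)        # out: (2, 128, 1536)
    return out, ln_out

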
class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = False
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('gelu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    activation_params: dict, default = None
        The parameters needed (if any) by the activation functions specified in
        :attr:`activations`. At the moment, only ('clamped_silu', 'clamped_linear'),
        the clamped SwiGLU variant used in GPT-OSS, needs additional parameters.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the RNGs given via flax.linen.Module.apply that is used to
        generate Dropout masks.
    intermediate_dropout_rate: float, default = 0.0
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout op.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str, default = 'ffn1'
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str, default = 'ffn2'
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name: Optional[str], default = None
        The name for checkpointing quantizations.
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = False
    activations: Sequence[Union[str, Callable]] = ("gelu",)
    activation_params: dict = None
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.0
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set(
            "_0", quantization_checkpoint_name=self.quantization_checkpoint_name
        )
        ffn2_quantizer_set = self.generate_quantizer_set(
            "_1", quantization_checkpoint_name=self.quantization_checkpoint_name
        )
        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            ffn1_quantizer_set != noop_quantizer_set
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
            ("clamped_silu", "clamped_linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            # Callables cannot be lower-cased; keep them as-is so they fall through to
            # the jax-based implementation below instead of returning early.
            normalized_acts.append(act.lower() if isinstance(act, str) else act)
        normalized_acts = tuple(
            reversed(normalized_acts)
            if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
            else normalized_acts
        )
        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)
        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if ffn1_quantizer_set == noop_quantizer_set:
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if ffn2_quantizer_set == noop_quantizer_set:
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))
        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
                activation_type=normalized_acts,
                activation_params=self.activation_params,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                    transpose_batch_sequence=self.transpose_batch_sequence,
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                # One extra leading axis holds the stacked per-activation kernels.
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, self.ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)
            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )
            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, self.ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output
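

# Illustrative usage sketch (editor's addition): a gated-GELU MLP block with
# dropout. Shapes, the activation pair, and the seeds are assumptions.
def _layernorm_mlp_usage_sketch():
    x = jnp.ones((2, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
    module = LayerNormMLP(
        intermediate_dim=2048,
        activations=("gelu", "linear"),             # gated GELU (GeGLU)
        intermediate_dropout_rate=0.1,
    )
    variables = module.init(jax_random.PRNGKey(0), x, deterministic=True)
    out, _ = module.apply(
        variables, x, deterministic=False, rngs={"dropout": jax_random.PRNGKey(1)}
    )
    return out                                      # -> (2, 128, 512)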