# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
import warnings

import numpy as np
import jax.numpy as jnp

from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxFusionType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..attention import AttnSoftmaxType

from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)

from ..quantize import (
    QuantizerFactory,
    get_global_quantize_recipe,
    QuantizeMetaSet,
    TensorSource,
    get_quantize_config_with_recipe,
    noop_quantizer_set,
)

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)

PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]

Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)
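

# Worked example (illustrative only, not used by the library) of how the two
# axis helpers above are combined by the modules below. The concrete rank and
# axis values are assumptions for the sake of the demo.
def _example_axis_helpers():
    axis = _canonicalize_tuple(-1)  # scalars are wrapped: -1 -> (-1,)
    return _normalize_axes(axis, 3)  # negative axes become absolute: (-1,) -> (2,)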


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return the given scale initializer, or the default implied by zero_centered_gamma."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    """Create the scale (and, for layernorm, shift) parameters of a normalization layer."""
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string name to an activation function; callables are passed through."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
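

# Quick demo (an illustrative sketch, not part of the module API) of the
# conversion helper above: strings resolve to flax.linen activations, "linear"
# becomes the identity, and callables pass through unchanged.
def _example_convert_activation():
    gelu_fn = _convert_to_activation_function("gelu")  # -> nn.gelu
    identity = _convert_to_activation_function("linear")  # -> lambda x: x
    x = jnp.linspace(-1.0, 1.0, 8)
    return gelu_fn(x), identity(x)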


def _combine_biases(*biases: Array):
    """Combine attention biases by summation, dropping None entries."""
    biases = [b for b in biases if b is not None]
    if not biases:
        return None

    assert all(
        b.ndim == biases[0].ndim for b in biases
    ), f"biases must have the same rank: {tuple(b.ndim for b in biases)}"

    bias, *other_biases = biases
    for other_bias in other_biases:
        bias = bias + other_bias
    return bias
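

# Minimal sketch (shapes are illustrative assumptions) of combining two
# broadcast-compatible attention biases with _combine_biases; None entries are
# dropped and the remaining biases are summed.
def _example_combine_biases():
    rel_bias = jnp.zeros((1, 8, 128, 128))
    alibi_bias = jnp.zeros((1, 8, 128, 128))
    return _combine_biases(rel_bias, alibi_bias, None)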


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low-rank adaptation (LoRA): computes scaling * (x @ A @ B) via a single einsum."""
    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
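

# Shape walk-through (an illustrative sketch; all sizes are assumptions) for
# _apply_low_rank_adaptation: with one contracting axis and one output feature
# dimension, the einsum is "...i,is,sn->...n" and the result is scaled by
# alpha / rank.
def _example_low_rank_adaptation():
    x = jnp.ones((2, 4, 256))  # (batch, seq, hidden_in)
    lora_a = jnp.zeros((256, 16))  # hidden_in -> rank
    lora_b = jnp.zeros((16, 128))  # rank -> hidden_out
    # alpha / rank = 32 / 16 = 2.0 scaling; output shape (2, 4, 128)
    return _apply_low_rank_adaptation(x, (2,), (128,), lora_a, lora_b, 32.0)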


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.

    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar multiplier applied to the input before softmax.
    softmax_fusion_type : SoftmaxFusionType, default = SoftmaxFusionType.SCALED
        Indicate the type of softmax fusion.
    softmax_type : AttnSoftmaxType, default = AttnSoftmaxType.VANILLA_SOFTMAX
        Indicate the type of attention softmax.
    """

    scale_factor: float = 1.0
    softmax_fusion_type: SoftmaxFusionType = SoftmaxFusionType.SCALED
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
        self, inputs: Array, mask: Array = None, bias: Array = None, softmax_offset: Array = None
    ) -> jnp.ndarray:
        """Apply softmax with optional mask/bias, using fused TE kernels when available."""
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        if softmax_offset is not None:
            assert self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX
        if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            softmax_offset = 0.0

        # Use the fused TE primitives when a kernel is available.
        if is_softmax_kernel_available(
            self.softmax_fusion_type,
            self.softmax_type,
            batch,
            heads,
            q_seqlen,
            k_seqlen,
            input_dtype,
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_fusion_type is not SoftmaxFusionType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_fusion_type)
        # Fall back to the default JAX-based implementation.
        else:
            warnings.warn(
                "Using unfused JAX softmax implementation instead of TE fused primitives.",
                UserWarning,
                stacklevel=2,
            )

            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_fusion_type is SoftmaxFusionType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor, softmax_offset)
            elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor, softmax_offset)
            elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(
                    logits, self.scale_factor, softmax_offset
                )
            else:
                raise ValueError(
                    f"Unsupported softmax fusion: {self.softmax_fusion_type}. softmax_fusion_type"
                    " must be [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs
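

# Usage sketch for Softmax (illustrative, not part of the public API): logits are
# shaped [batch, heads, q_seqlen, k_seqlen], and the module transparently picks
# the fused TE kernel or the JAX fallback. Shapes and scale are assumptions.
def _example_softmax_usage():
    logits = jnp.ones((2, 8, 128, 128), dtype=jnp.float32)
    module = Softmax(scale_factor=1.0 / np.sqrt(64.0))
    variables = module.init(jax_random.PRNGKey(0), logits)
    return module.apply(variables, logits)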


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters
    matching the size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is a learnable affine transform parameter
    matching the size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype
        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out
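

# Usage sketch for LayerNorm (illustrative, not part of the public API),
# assuming a rank-3 activation with the feature dimension last.
def _example_layernorm_usage():
    x = jnp.ones((4, 16, 256), dtype=jnp.float32)
    layer = LayerNorm(layernorm_type="rmsnorm")
    variables = layer.init(jax_random.PRNGKey(0), x)
    return layer.apply(variables, x)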


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of Transformer Engine modules; provides FP8 quantizer-set helpers.
    """

    def generate_quantizer_set(
        self,
        postfix: str = "",
        variable_collection: str = None,
        quantization_checkpoint_name: Optional[str] = None,
        fp8_recipe=None,
    ):
        """
        Generate a set of FP8 meta for a GEMM.
        """
        if fp8_recipe is None:
            fp8_recipe = get_global_quantize_recipe()

        quantize_config = get_quantize_config_with_recipe(fp8_recipe)

        collection_name = (
            variable_collection
            if variable_collection is not None
            else quantize_config.COLLECTION_NAME
        )

        x_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.X, "x"
        )
        kernel_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.KERNEL, "kernel"
        )
        grad_meta = quantize_config.get_quantize_flax_meta(
            self, collection_name, postfix, TensorSource.DGRAD, "grad"
        )

        quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)

        quantizer_set = QuantizerFactory.create_set(
            fp8_recipe=fp8_recipe,
            quantize_meta_set=quantize_meta_set,
            checkpoint_name=quantization_checkpoint_name,
        )
        return quantizer_set
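

# Sketch (illustrative, not a public API) of how a TransformerEngineBase
# subclass is expected to thread generate_quantizer_set() into the functional
# dense() call; the class name, shapes, and initializer here are assumptions.
def _example_te_base_subclass():
    class _Projection(TransformerEngineBase):
        features: int = 128

        @nn.compact
        def __call__(self, x):
            kernel = self.param(
                "kernel",
                nn.initializers.lecun_normal(),
                (x.shape[-1], self.features),
                jnp.float32,
            )
            quantizer_set = self.generate_quantizer_set()
            return dense(
                x,
                kernel,
                contracting_dims=((x.ndim - 1,), (0,)),
                input_axes=(),
                kernel_axes=(),
                quantizer_set=quantizer_set,
            )

    return _Projection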


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias : bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation : bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim : int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha : float, default = None
        The alpha for computing the scaling factor of LoRA output:
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
    axis : Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes : Tuple[str, ...], default = ()
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is (), which means
        no sharding constraint is inserted.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name : Optional[str], default = None
        The name for checkpointing quantizations.
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    input_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes), "
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        quantizer_set = self.generate_quantizer_set(
            quantization_checkpoint_name=self.quantization_checkpoint_name
        )

        if quantizer_set == noop_quantizer_set:
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y
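

# Usage sketch for DenseGeneral (illustrative; shapes are assumptions). Outside
# an fp8/quantization autocast context, the module falls back to the no-op
# quantizer set and behaves like a plain dense layer.
def _example_dense_general_usage():
    x = jnp.ones((4, 16, 256), dtype=jnp.float32)
    layer = DenseGeneral(features=512, use_bias=True)
    variables = layer.init(jax_random.PRNGKey(0), x)
    return layer.apply(variables, x)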


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm : bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias : bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output : bool, default = False
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    enable_low_rank_adaptation : bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim : int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha : float, default = None
        The alpha for computing the scaling factor of LoRA output:
        :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
    axis : Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes : Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes : Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    depth_scaling : float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    transpose_batch_sequence : bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
    quantization_checkpoint_name : Optional[str], default = None
        The name for checkpointing quantizations.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = False
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None
    transpose_batch_sequence: bool = False
    quantization_checkpoint_name: Optional[str] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs : jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set(
            quantization_checkpoint_name=self.quantization_checkpoint_name
        )

        fuse_layernorm = (
            quantizer_set != noop_quantizer_set
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if quantizer_set == noop_quantizer_set:
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
                transpose_batch_sequence=self.transpose_batch_sequence,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                transpose_batch_sequence=self.transpose_batch_sequence,
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output
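

# Usage sketch for LayerNormDenseGeneral (illustrative; shapes are assumptions).
# The module returns a pair: the dense output and, when requested, the
# intermediate layernorm output (layernorm/dense fusion is skipped in that case).
def _example_layernorm_dense_general_usage():
    x = jnp.ones((4, 16, 256), dtype=jnp.float32)
    layer = LayerNormDenseGeneral(features=512, return_layernorm_output=True)
    variables = layer.init(jax_random.PRNGKey(0), x)
    out, ln_out = layer.apply(variables, x)
    return out, ln_out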


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
877
    consisting of 2 successive dense layer transformations, separated by given activations.
878
879
880
881

    Parameters
    ----------
    intermediate_dim: int, default = 2048
882
        Intermediate size to which input samples are projected.
883
    enable_layernorm: bool, default = True
884
        Indicate whether to enable layer normalization before dense layer transformation.
885
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
886
        Indicate the type of layer normalization.
887
    epsilon : float, default = 1e-6
888
        A value added to the denominator of layer normalization for numerical stability.
889
890
891
892
893
894
895
896
897
898
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
899
        Used for initializing scale factors :math:`\gamma`.
900
901
902
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
903
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
904
    scale_axes : Tuple[str, ...], default = ('embed', )
905
906
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
907
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
908
909
910
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
911
912
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
913
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
914
915
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
916
        Used for initializing the weights of both dense layer transformations.
917
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
918
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
919
        The name of axes used to shard the weights with a corresponding mesh for
920
        the weight of the first dense layer transformation.
921
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
922
        The name of axes used to shard the weights with a corresponding mesh for
923
        the weight of the second dense layer transformation.
924
    use_bias: bool, default = False
925
926
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
927
    bias_init: Initializer, default = flax.linen.initializers.zeros
928
929
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
930
    bias_axes_1: Tuple[str, ...], default = ('mlp',)
931
        The name of axes used to shard bias with a corresponding mesh  for
932
        the weight of the first dense layer transformation.
933
        Only used when :attr:`use_bias=True`.
934
    bias_axes_2: Tuple[str, ...], default = ('embed',)
935
        The name of axes used to shard bias with a corresponding mesh  for
936
        the weight of the second dense layer transformation.
937
        Only used when :attr:`use_bias=True`.
938
    return_layernorm_output: bool, default = False
939
        Indicate whether to return the output of layer normalization.
940
        If set False, return None as the second tensor in outputs.
941
    activations: Sequence[Union[str, Callable]], default = ('gelu',)
942
        The sequence of activation functions to apply after the first dense layer transformation.
943
        Each activation has its own transformation layer.
944
945
946
947
    activation_params: dict, default = None
        The parameters needed(if any) by the activation functions specified in :attr:`activations`.
        At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS
        need additional parameters.
948
949
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
950
    intermediate_dropout_rate: float, default = 0.0
951
        Dropout probability for the dropout op after the :attr:`activations`.
Ming-Xu Huang's avatar
Ming-Xu Huang committed
952
953
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden
954
    enable_low_rank_adaptation: bool, default = False
955
        Indicate whether to enable low rank adaptation for each dense layer.
956
957
958
959
960
961
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
962
    axis:  Union[Iterable[int], int], default = -1
963
        An integer tuple with axes to apply the transformation on.
964
965
966
967
968
969
970
971
972
973
974
975
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
976
977
978
979
980
    ffn1_ckpt_name: str = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
    ffn2_ckpt_name: str = "ffn2"
        Checkpoint name for the output of the second fully-connected layer in the MLP block.

981
982
983

    Optimization parameters
    -----------------------
984
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
985
        The data type used to allocate the initial parameters.
986
987
    transpose_batch_sequence: bool, default = False
        Indicate whether to transpose the batch and sequence dimensions of the input tensor.
988
989
    quantization_checkpoint_name: Optional[str], default = None
        The name for checkpointing quantizations.
990
991
992
993
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
994
    layernorm_type: str = "layernorm"
995
    epsilon: float = 1e-6
996
997
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
998
    scale_axes: Tuple[str, ...] = ("embed",)
999
    ln_bias_init: Initializer = nn.initializers.zeros
1000
    ln_bias_axes: Tuple[str, ...] = ("embed",)
1001
    kernel_init: Initializer = None
1002
1003
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
1004
1005
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
1006
1007
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
1008
1009
    return_layernorm_output: bool = False
    activations: Sequence[Union[str, Callable]] = ("gelu",)
1010
    activation_params: dict = None
1011
    intermediate_dropout_rng_name: str = "dropout"
1012
    intermediate_dropout_rate: float = 0.0
Ming-Xu Huang's avatar
Ming-Xu Huang committed
1013
    intermediate_hidden_dropout_dims: Sequence[int] = ()
1014
1015
1016
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
1017
1018
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
1019
1020
1021
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
1022
1023
    ffn1_ckpt_name: str = "ffn1"
    ffn2_ckpt_name: str = "ffn2"
1024
    transpose_batch_sequence: bool = False
1025
    quantization_checkpoint_name: Optional[str] = None
1026
1027
1028

    def __post_init__(self):
        if self.kernel_init is None:
1029
            self.kernel_init = nn.initializers.variance_scaling(
1030
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
1031
            )
1032
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
1033
1034
            self.scale_init,
            self.zero_centered_gamma,
1035
        )
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
1056
            If :attr:`return_layernorm_output=False`, then this would be None.
1057
        """
1058
1059
        assert self.axis == -1, "Only support axis == -1 at this moment"

1060
1061
1062
1063
1064
1065
        ffn1_quantizer_set = self.generate_quantizer_set(
            "_0", quantization_checkpoint_name=self.quantization_checkpoint_name
        )
        ffn2_quantizer_set = self.generate_quantizer_set(
            "_1", quantization_checkpoint_name=self.quantization_checkpoint_name
        )
1066

1067
        input_dtype = inputs.dtype
1068
1069
        ln_output = None

1070
1071
        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
1072
        fuse_layernorm = (
1073
            ffn1_quantizer_set != noop_quantizer_set
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
1084
            ("clamped_silu", "clamped_linear"),
1085
1086
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
1087
        normalized_acts = []
1088
1089
1090
        for act in self.activations:
            if not isinstance(act, str):
                return False
1091
            normalized_acts.append(act.lower())
1092
        normalized_acts = tuple(
1093
1094
1095
            reversed(normalized_acts)
            if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
            else normalized_acts
1096
        )
1097

1098
        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
1099

1100
1101
1102
        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
1103
1104
        # LayerNorm
        if self.enable_layernorm:
1105
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
1106

1107
1108
            features = inputs.shape[-1]

1109
            scale, ln_bias = _create_layernorm_parameters(
1110
                self,
1111
1112
1113
1114
1115
1116
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
1117
                input_dtype,
1118
1119
                self.dtype,
            )
1120
1121

            if not fuse_layernorm:
1122
1123
1124
1125
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
1126
                    norm_type=self.layernorm_type,
1127
1128
1129
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
1144
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)
1145

1146
        num_activations = len(normalized_acts)
1147
1148
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
1149
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
1150
        kernel_1 = self.param(
1151
            "wi_kernel",
1152
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
1153
1154
1155
            num_activations,
            -2,
            kernel_1_each_shape,
1156
            self.dtype,
1157
        )
1158

1159
        if ffn1_quantizer_set == noop_quantizer_set:
1160
            kernel_1 = kernel_1.astype(input_dtype)
1161

1162
1163
1164
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
1165
        kernel_2 = self.param(
1166
            "wo_kernel",
1167
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
1168
            kernel_2_shape,
1169
            self.dtype,
1170
        )
1171
        if ffn2_quantizer_set == noop_quantizer_set:
1172
            kernel_2 = kernel_2.astype(input_dtype)
1173

1174
        contract_ind = tuple(range(0, len(axis)))
1175

1176
        if self.use_bias:
1177
            bias_1_shape = (num_activations, self.intermediate_dim)
1178
            bias_1 = self.param(
1179
                "wi_bias",
1180
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
1181
1182
                bias_1_shape,
                self.dtype,
1183
            ).astype(input_dtype)
1184
1185

            bias_2_shape = (hidden_size,)
1186
            bias_2 = self.param(
1187
                "wo_bias",
1188
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
1189
1190
                bias_2_shape,
                self.dtype,
1191
            ).astype(input_dtype)
1192
1193
1194
1195
        else:
            bias_1 = None
            bias_2 = None

1196
        if use_fused_layernorm_mlp:
1197
            out = layernorm_mlp(
1198
1199
1200
1201
1202
1203
1204
1205
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
1206
                norm_input_axes=self.layernorm_input_axes,
1207
1208
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
1209
1210
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
1211
1212
                ffn1_ckpt_name=self.ffn1_ckpt_name,
                ffn2_ckpt_name=self.ffn2_ckpt_name,
1213
                activation_type=normalized_acts,
1214
                activation_params=self.activation_params,
1215
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
1216
                transpose_batch_sequence=self.transpose_batch_sequence,
1217
            )
1218
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
1219
1220

        else:  # not use_fused_ln_geglu_mlp
1221
1222
            # DenseGeneral 1
            if fuse_layernorm:
1223
                x = layernorm_dense(
1224
1225
1226
1227
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
1228
                    norm_type=self.layernorm_type,
1229
1230
1231
1232
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
1233
                    kernel_axes=self.kernel_axes_1,
1234
                    quantizer_set=ffn1_quantizer_set,
1235
                    transpose_batch_sequence=self.transpose_batch_sequence,
1236
                )
1237
            else:
1238
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
1239
1240
1241
1242
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
1243
1244
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
1245
                    quantizer_set=ffn1_quantizer_set,
1246
                    transpose_batch_sequence=self.transpose_batch_sequence,
1247
                )
1248

1249
            if self.enable_low_rank_adaptation:
1250
1251
                wi_lora_a_kernel_each_shape = (
                    kernel_1_each_shape[: len(axis)],
1252
1253
                    self.low_rank_adaptation_dim,
                )
1254
                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1)
1255
                wi_lora_a_kernel = self.param(
1256
                    "wi_lora_a_kernel",
1257
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
1258
                    num_activations,
1259
1260
                    -2,
                    wi_lora_a_kernel_each_shape,
1261
                    self.dtype,
1262
                ).astype(input_dtype)
1263

1264
1265
1266
1267
1268
                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
1269
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
1270
                wi_lora_b_kernel = self.param(
1271
                    "wi_lora_b_kernel",
1272
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
1273
                    wi_lora_b_kernel_shape,
1274
                    self.dtype,
1275
                ).astype(input_dtype)
1276

1277
1278
1279
                x += _apply_low_rank_adaptation(
                    y,
                    axis,
1280
                    (num_activations, self.intermediate_dim),
1281
1282
1283
1284
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )
1285

1286
            if self.use_bias:
1287
                x += jnp.reshape(bias_1, bias_1_shape)
1288

1289
            x = checkpoint_name(x, self.ffn1_ckpt_name)
1290
            if is_act_implemented:
1291
                z = activation(x, normalized_acts)
1292
            else:
1293
                activations = []
1294
                x = jnp.split(x, num_activations, axis=-2)
1295
                for idx, act_fn in enumerate(normalized_acts):
1296
1297
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
1298
                z = reduce(operator.mul, activations)
1299
                z = jnp.squeeze(z, axis=-2)
1300
            z = z.astype(input_dtype)
1301

1302
1303
1304
1305
1306
            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)
1307

1308
            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
1309
            z = z.astype(input_dtype)
1310

1311
            # DenseGeneral 2
1312
            out = dense(
1313
1314
1315
1316
1317
1318
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
1319
                transpose_batch_sequence=self.transpose_batch_sequence,
1320
            )
1321

1322
1323
1324
            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
1325
                wo_lora_a_kernel = self.param(
1326
                    "wo_lora_a_kernel",
1327
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
1328
                    wo_lora_a_kernel_shape,
1329
                    self.dtype,
1330
                ).astype(input_dtype)
1331
1332
1333

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
1334
                wo_lora_b_kernel = self.param(
1335
                    "wo_lora_b_kernel",
1336
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
1337
                    wo_lora_b_kernel_shape,
1338
                    self.dtype,
1339
                ).astype(input_dtype)
1340

1341
1342
1343
1344
1345
1346
1347
1348
                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )
1349

1350
            if self.use_bias:
1351
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
1352

1353
            out = checkpoint_name(out, self.ffn2_ckpt_name)
1354

1355
        assert out.dtype == input_dtype
1356
        return out, ln_output  # Output, layer_norm_output
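

# Usage sketch for LayerNormMLP (illustrative; shapes are assumptions): a gated
# GELU MLP ("gelu" + "linear") with dropout disabled via deterministic=True, so
# no dropout RNG needs to be provided.
def _example_layernorm_mlp_usage():
    x = jnp.ones((4, 16, 256), dtype=jnp.float32)
    layer = LayerNormMLP(intermediate_dim=1024, activations=("gelu", "linear"))
    variables = layer.init(jax_random.PRNGKey(0), x, deterministic=True)
    out, _ = layer.apply(variables, x, deterministic=True)
    return out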