# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense, _issue_batch_first_warning as _dense_warning

from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
    jax_scaled_masked_softmax,
    jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
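    # Example: _normalize_axes((0, -1), ndim=3) -> (0, 2)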
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    module,
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = module.param(
        "scale",
        nn.with_logical_partitioning(scale_init, scale_axes),
        shape,
        dtype,
    ).astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = module.param(
            "ln_bias",
            nn.with_logical_partitioning(bias_init, bias_axes),
            shape,
            dtype,
        ).astype(input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
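    # For the common case of a single contracting axis and a single output
    # feature dimension this builds "...i,is,sn->...n", i.e. x @ lora_a @ lora_b
    # contracted through the low-rank dimension "s".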
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
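
    Examples
    --------
    A minimal usage sketch (shapes and the scale factor are illustrative only,
    not a recommended configuration):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        softmax_module = Softmax(scale_factor=0.125)
        logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
        variables = softmax_module.init(jax.random.PRNGKey(0), logits)
        probs = softmax_module.apply(variables, logits)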
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        input_dtype = inputs.dtype
        logits = inputs

        # use primitives
        if is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        # use default jax based implementation
        else:
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            if self.softmax_type is SoftmaxType.SCALED:
                outputs = jax_scaled_softmax(logits, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
            else:
                raise ValueError(
                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )
        assert input_dtype == outputs.dtype
        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
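
    Examples
    --------
    A minimal usage sketch (array shapes are illustrative only):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNorm(layernorm_type="rmsnorm")
        x = jnp.ones((16, 128, 512), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # same shape and dtype as x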
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        input_dtype = x.dtype

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self,
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            input_dtype,
            self.dtype,
        )
        out = layernorm(
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
        assert out.dtype == input_dtype
        return out


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
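            # FP8 metadata is stored as Flax variables in the
            # QuantizeConfig.COLLECTION_NAME collection: a per-tensor scale
            # initialized to ones and an amax history of AMAX_HISTORY_LEN zeros.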
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    sequence_parallel_output: bool, default = False
        Produce a sequence-parallel output with the first non-batch dimension sharded over
        the tensor-parallel mesh axis.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
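
    Examples
    --------
    A minimal usage sketch (shapes and the feature size are illustrative only,
    not a recommended configuration):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = DenseGeneral(features=1024)
        x = jnp.ones((16, 128, 512), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # shape (16, 128, 1024)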
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    input_axes: Tuple[str, ...] = ()
    sequence_parallel_output: bool = False

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _dense_warning(
                "TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features

        if self.kernel_axes:
            assert len(kernel_shape) == len(self.kernel_axes), (
                "Expected len(kernel_shape) to match len(kernel_axes),"
                f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
            )
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs,
            kernel,
            contracting_dims=(axis, contract_ind),
            input_axes=self.input_axes,
            kernel_axes=self.kernel_axes,
            quantizer_set=quantizer_set,
            sequence_parallel_output=self.sequence_parallel_output,
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
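
    Examples
    --------
    A minimal usage sketch (shapes and features are illustrative only; the
    module returns a tuple of the dense output and the layernorm output):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormDenseGeneral(features=1024, transpose_batch_sequence=False)
        x = jnp.ones((16, 128, 512), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)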
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _ln_dense_warning(
                "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
                "inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
                "Use sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
            features = inputs.shape[-1]
            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(
                y,
                kernel,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_input_axes,
                kernel_axes=self.kernel_axes,
                quantizer_set=quantizer_set,
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
            lora_a_kernel = self.param(
                "lora_a_kernel",
                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                lora_a_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = self.param(
                "lora_b_kernel",
                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                lora_b_kernel_shape,
                self.dtype,
            ).astype(input_dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                features,
                self.dtype,
            ).astype(input_dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
        # z = z.reshape(*inputs.shape[: self.axis], *features)
        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive dense layer transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second dense layer transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
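
    Examples
    --------
    A minimal usage sketch (shapes, the intermediate size and the activation
    pair are illustrative only):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormMLP(
            intermediate_dim=2048,
            activations=("gelu", "linear"),
            transpose_batch_sequence=False,
        )
        x = jnp.ones((16, 128, 512), dtype=jnp.bfloat16)
        variables = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = layer.apply(variables, x, deterministic=True)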
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.transpose_batch_sequence:
            _ln_mlp_warning(
                "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
                "and may produce incorrect results when `transpose_batch_sequence=True`. Use "
                "sequence-first inputs at your own discretion."
            )
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        assert self.axis == -1, "Only support axis == -1 at this moment"

        ffn1_quantizer_set = self.generate_quantizer_set("_0")
        ffn2_quantizer_set = self.generate_quantizer_set("_1")

        input_dtype = inputs.dtype
        ln_output = None

        # TODO(Phuong): use fuse_layernorm for high-precision
        # when NoOpQuantizer and Tensor are implemented
        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                return False
            normalized_acts.append(act.lower())
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )
        # LayerNorm
        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self,
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                input_dtype,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
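            # Draws one independent kernel per activation with `self.kernel_init`
            # and stacks them along `stack_axis`, giving wi_kernel an extra
            # num_activations dimension.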
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)
        kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = self.param(
            "wi_kernel",
            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
            num_activations,
            -2,
            kernel_1_each_shape,
            self.dtype,
        )

        if not QuantizeConfig.is_fp8_enabled():
            kernel_1 = kernel_1.astype(input_dtype)

        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2 = self.param(
            "wo_kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
            kernel_2_shape,
            self.dtype,
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel_2 = kernel_2.astype(input_dtype)

        contract_ind = tuple(range(0, len(axis)))

        if self.use_bias:
            bias_1_shape = (num_activations, self.intermediate_dim)
            bias_1 = self.param(
                "wi_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                bias_1_shape,
                self.dtype,
            ).astype(input_dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = self.param(
                "wo_bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                bias_2_shape,
                self.dtype,
            ).astype(input_dtype)
        else:
            bias_1 = None
            bias_2 = None

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            out = layernorm_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                norm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                kernel_1_axes=self.kernel_axes_1,
                kernel_2_axes=self.kernel_axes_2,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
            )
            out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)

        else:  # not use_fused_ln_geglu_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_dense(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    norm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = dense(
                    y,
                    kernel_1,
                    contracting_dims=(axis, contract_ind),
                    input_axes=self.dot_1_input_axes,
                    kernel_axes=self.kernel_axes_1,
                    quantizer_set=ffn1_quantizer_set,
                )

            if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
                dot_1_output_axes = (
                    *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
                    *get_non_contracting_logical_axes(
                        kernel_1.ndim, self.kernel_axes_1, contract_ind
                    ),
                )
                x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_each_shape = (
                    *kernel_1_each_shape[: len(axis)],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
                wi_lora_a_kernel = self.param(
                    "wi_lora_a_kernel",
                    nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                    num_activations,
                    -2,
                    wi_lora_a_kernel_each_shape,
                    self.dtype,
                ).astype(input_dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = self.param(
                    "wi_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                    wi_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    (num_activations, self.intermediate_dim),
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = reduce(operator.mul, activations)
                z = jnp.squeeze(z, axis=-2)
            z = z.astype(input_dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(input_dtype)

            # DenseGeneral 2
            out = dense(
                z,
                kernel_2,
                contracting_dims=(axis, contract_ind),
                input_axes=self.dot_2_input_axes,
                kernel_axes=self.kernel_axes_2,
                quantizer_set=ffn2_quantizer_set,
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = self.param(
                    "wo_lora_a_kernel",
                    nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                    wo_lora_a_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = self.param(
                    "wo_lora_b_kernel",
                    nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                    wo_lora_b_kernel_shape,
                    self.dtype,
                ).astype(input_dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            if self.use_bias:
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output