"tests/vscode:/vscode.git/clone" did not exist on "d10f8e1d43bfb0656b6848ad0c681ecbdec812d6"
module.py 51.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
10
11
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..layernorm_mlp import fused_layernorm_fp8_mlp, activation_lu
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
):
    scale = nn_partitioning.param_with_axes(
        "scale", scale_init, shape, weight_dtype, axes=scale_axes
    )
    scale = scale.astype(dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "layernorm":
        bias = nn_partitioning.param_with_axes(
            "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
        )
        bias = bias.astype(dtype)
    else:
        assert layernorm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )
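    # e.g. with axis=(-1,) and features=(out_dim,), this assembles
    # "...i,is,sn->...n": x is contracted with lora_a over the hidden axis "i"
    # and with lora_b over the rank axis "s".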

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python
        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e-10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
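
    Example
    -------
    A minimal usage sketch; the shapes and scale factor below are illustrative
    assumptions rather than requirements of this module:

    .. code-block:: python

        import jax.numpy as jnp

        # (batch, heads, q_seqlen, k_seqlen)
        logits = jnp.zeros((4, 16, 128, 128), dtype=jnp.bfloat16)
        softmax_fn = Softmax(scale_factor=1.0 / 8.0)
        probs = softmax_fn.apply({}, logits)  # Softmax holds no parameters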
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs

        if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype
        ):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(
                    mask > 0,
                    jnp.full(mask.shape, -1e10).astype(dtype),
                    jnp.full(mask.shape, 0.0).astype(dtype),
                )

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # For the case that self.softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED
            # and the kernel is unavailable, fall back to the pure scaled softmax custom call.
            if is_softmax_kernel_available(
                SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
            ):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type of the module parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
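
    Example
    -------
    A minimal usage sketch; the shapes, PRNG key and variable handling below are
    illustrative assumptions and any mesh/axis-rules context is omitted:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((32, 128, 512))  # (batch, seqlen, hidden)
        layer = LayerNorm(layernorm_type="rmsnorm", epsilon=1e-6)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)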
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        x = x.astype(self.dtype)

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            self.dtype,
            self.weight_dtype,
        )
        return layernorm(
            x,
            scale,
            ln_bias,
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    @staticmethod
    def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage:
        """
        Generate a set of FP8 meta for a GEMM.
        """

        input_name_post_fix = f"_i_{postfix}"
        weight_name_post_fix = f"_w_{postfix}"
        grad_name_post_fix = f"_g_{postfix}"
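        # For each GEMM, an amax history and a scaling factor are tracked for the
        # input, the weight and the gradient; they are stored as variables in the
        # FP8Helper.FP8_COLLECTION_NAME collection under the postfixes above.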

        def generate_a_set(target_postfix):
            amax = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}",
                jnp.zeros,
                (FP8Helper.AMAX_HISTORY_LEN,),
                jnp.float32,
                axes=(None,),
            )

            scale = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}",
                jnp.ones,
                (1,),
                jnp.float32,
                axes=(None,),
            )

            return amax.value, scale.value

        input_amax, input_scale = generate_a_set(input_name_post_fix)
        weight_amax, weight_scale = generate_a_set(weight_name_post_fix)
        grad_amax, grad_scale = generate_a_set(grad_name_post_fix)

        return FP8MetaPackage(
            input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale
        )


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type of the module parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
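
    Example
    -------
    A minimal usage sketch; the shapes, logical axis names and variable handling
    below are illustrative assumptions and any mesh/axis-rules or FP8 autocast
    context is omitted:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((32, 128, 512))  # (batch, seqlen, hidden)
        layer = DenseGeneral(features=1024, kernel_axes=("embed", "mlp"),
                             use_bias=False)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)  # (32, 128, 1024)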
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
        )
        kernel = kernel.astype(self.dtype)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        y = type_safe_dot_general(
            inputs, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                self.weight_dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.weight_dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type of the module parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
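
    Example
    -------
    A minimal usage sketch; the shapes and logical axis names below are
    illustrative assumptions and any mesh/axis-rules or FP8 autocast context is
    omitted:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((128, 32, 512))  # (seqlen, batch, hidden) by default
        layer = LayerNormDenseGeneral(features=1024,
                                      kernel_axes=("embed", "mlp"))
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)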
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0,
                "fan_in",
                "truncated_normal",
                dtype=self.weight_dtype,
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )
        inputs = inputs.astype(self.dtype)

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            assert self.axis == -1  # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                self.dtype,
                self.weight_dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
        )
        kernel = kernel.astype(self.dtype)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        if fuse_layernorm:
            z = layernorm_fp8_dot(
                y,
                kernel,
                scale,
                ln_bias,
                fp8_meta_pkg,
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = type_safe_dot_general(
                y, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                self.weight_dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                self.weight_dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output  # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformations.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformations.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first linear transformations.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second linear transformations.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout
        after the :attr:`activations`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type used for computation.
    weight_dtype: jax.numpy.dtype, default  = jax.numpy.float32
        The data type of the module parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
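
    Example
    -------
    A minimal usage sketch; the shapes and RNG handling below are illustrative
    assumptions and any mesh/axis-rules or FP8 autocast context is omitted:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        x = jnp.ones((128, 32, 512))  # (seqlen, batch, hidden) by default
        mlp = LayerNormMLP(intermediate_dim=2048, activations=("gelu", "linear"))
        variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
        out, ln_out = mlp.apply(variables, x, deterministic=True)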
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    weight_dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
            )
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init,
            self.zero_centered_gamma,
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        inputs = inputs.astype(self.dtype)

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
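        # Activation tuples that include "linear" denote gated activations (e.g.
        # ("gelu", "linear") for GeGLU): the "linear" branch is multiplied
        # element-wise with the activated branch, as in the fallback below that
        # reduces the per-activation outputs with operator.mul.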
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                return False
            normalized_acts.append(act.lower())
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1  # Only support axis == -1 at this moment
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                self.dtype,
                self.weight_dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

        wi_fp8_meta_pkg = None
        wo_fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            wi_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
            wo_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("1")

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes(
            "wi_kernel",
            kernel_1_init,
            num_activations,
            -2,
            kernel_1_each_shape,
            self.weight_dtype,
            axes=self.kernel_axes_1,
        )
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        kernel_1 = kernel_1.astype(self.dtype)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes(
            "wo_kernel",
            self.kernel_init,
            kernel_2_param_shape,
            self.weight_dtype,
            axes=self.kernel_axes_2,
        )
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        kernel_2 = kernel_2.astype(self.dtype)
        contract_ind = tuple(range(0, len(axis)))

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            assert self.axis == -1  # Only support axis == -1 at this moment

            if self.use_bias:
                bias_1_shape = intermediate_dim
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias",
                    self.bias_init,
                    bias_1_shape,
                    self.weight_dtype,
                    axes=self.bias_axes_1,
                )
                bias_1 = bias_1.astype(self.dtype)

                bias_2_shape = (hidden_size,)
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias",
                    self.bias_init,
                    bias_2_shape,
                    self.weight_dtype,
                    axes=self.bias_axes_2,
                )
                bias_2 = bias_2.astype(self.dtype)
            else:
                bias_1 = None
                bias_2 = None

            out = fused_layernorm_fp8_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                [wi_fp8_meta_pkg, wo_fp8_meta_pkg],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                use_bias=self.use_bias,
            )

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_fp8_dot(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    wi_fp8_meta_pkg,
                    self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = type_safe_dot_general(
                    y, kernel_1, fp8_meta_pkg=wi_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_shape = (
                    *kernel_1_shape[: len(axis)],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_shape = (
                    kernel_1_each_shape[0],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_each_shape = (
                    kernel_1_each_shape[0],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
                wi_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_a_kernel",
                    kernel_1_init,
                    num_activations,
                    -2,
                    wi_lora_a_kernel_init_each_shape,
                    self.weight_dtype,
                    axes=wi_lora_a_kernel_axes,
                )
                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
                wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_b_kernel",
                    nn.initializers.zeros,
                    wi_lora_b_kernel_shape,
                    self.weight_dtype,
                    axes=wi_lora_b_kernel_axes,
                )
                wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    intermediate_dim,
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_1 = None
            if self.use_bias:
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias",
                    self.bias_init,
                    intermediate_dim,
                    self.weight_dtype,
                    axes=self.bias_axes_1,
                )
                bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
                bias_1 = bias_1.astype(self.dtype)
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation_lu(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                # Remove act axis
                z = jnp.reshape(z, (*z.shape[:-2], -1))
            z = z.astype(self.dtype)

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
            z = z.astype(self.dtype)

            # DenseGeneral 2
            out = type_safe_dot_general(
                z, kernel_2, fp8_meta_pkg=wo_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_a_kernel",
                    self.kernel_init,
                    wo_lora_a_kernel_shape,
                    self.weight_dtype,
                    axes=wo_lora_a_kernel_axes,
                )
                wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_b_kernel",
                    nn.initializers.zeros,
                    wo_lora_b_kernel_shape,
                    self.weight_dtype,
                    axes=wo_lora_b_kernel_axes,
                )
                wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_2 = None
            if self.use_bias:
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias",
                    self.bias_init,
                    (hidden_size,),
                    self.weight_dtype,
                    axes=self.bias_axes_2,
                )
                bias_2 = bias_2.astype(self.dtype)
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        return out, ln_output  # Output, layer_norm_output