# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..layernorm_mlp import fused_layernorm_fp8_mlp, activation_lu
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    """Return `original_init` if given, otherwise pick the default scale initializer
    based on `zero_centered_gamma` (zeros when True, ones when False)."""
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones
    return nn.initializers.zeros


def _create_layernorm_parameters(
    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
):
    """Create the layernorm scale (and, for 'layernorm', the bias) parameters with axis metadata."""
    scale = nn_partitioning.param_with_axes(
        "scale", scale_init, shape, jnp.float32, axes=scale_axes
    )
    scale = jnp.asarray(scale, dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "layernorm":
        bias = nn_partitioning.param_with_axes(
            "ln_bias", bias_init, shape, jnp.float32, axes=bias_axes
        )
        bias = jnp.asarray(bias, dtype)
    else:
        assert layernorm_type == "rmsnorm"
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = "ijklm"[: len(axis)]
    assert len(features) <= 5
    hidden_out_names = "nopqr"[: len(features)]
    rank_name = "s"

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = (
        f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}"
        f"->{output_einsum_express}"
    )

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
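

# Editor's note: an illustrative reference (not used by the library) for what the
# einsum in `_apply_low_rank_adaptation` reduces to in the common case of a single
# contracting axis and a single output feature axis. The helper name and shapes are
# hypothetical, added only to clarify the LoRA math: out = (x @ A @ B) * alpha / rank.
def _lora_reference_single_axis(x, lora_a, lora_b, alpha=None):
    """Reference LoRA update for x:[..., d_in], lora_a:[d_in, rank], lora_b:[rank, d_out]."""
    rank = lora_a.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0
    return jnp.einsum("...i,ir,rn->...n", x, lora_a, lora_b) * scaling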


class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.

    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scale factor applied to the input of softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs

        if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
            self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype
        ):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(
                    mask > 0,
                    jnp.full(mask.shape, -1e10).astype(dtype),
                    jnp.full(mask.shape, 0.0).astype(dtype),
                )

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # When self.softmax_type (e.g. SCALED_UPPER_TRIANG_MASKED) has no kernel
            # available, fall back to the pure scaled softmax custom call if possible.
            if is_softmax_kernel_available(
                SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
            ):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs
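

# Editor's note: a minimal usage sketch (editor-added, with assumed shapes) for the
# Softmax module above; with SoftmaxType.SCALED and no fused kernel available it
# falls back to jax.nn.softmax on the scaled logits.
def _example_softmax_usage():
    logits = jnp.ones((2, 4, 8, 8), dtype=jnp.float32)  # [batch, heads, q_seqlen, k_seqlen]
    module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
    variables = module.init(jax_random.PRNGKey(0), logits)
    return module.apply(variables, logits)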


class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors have the batch and sequence-length axes
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    """

    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ("embed",)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(
            self.layernorm_type,
            (features,),
            self.scale_init,
            self.scale_axes,
            self.bias_init,
            self.bias_axes,
            self.dtype,
        )
        return layernorm(
            x,
            scale,
            ln_bias,
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
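

# Editor's note: a minimal usage sketch (editor-added, with assumed shapes) showing
# init/apply of the LayerNorm module defined above.
def _example_layernorm_usage():
    x = jnp.ones((8, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
    module = LayerNorm(layernorm_type="rmsnorm", epsilon=1e-6)
    variables = module.init(jax_random.PRNGKey(0), x)
    return module.apply(variables, x)  # same shape as x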


class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    @staticmethod
    def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage:
        """
        Generate a set of FP8 meta for a GEMM.
        """

        input_name_post_fix = f"_i_{postfix}"
        weight_name_post_fix = f"_w_{postfix}"
        grad_name_post_fix = f"_g_{postfix}"

        def generate_a_set(target_postfix):
            amax = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}",
                jnp.zeros,
                (FP8Helper.AMAX_HISTORY_LEN,),
                jnp.float32,
                axes=(None,),
            )

            scale = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}",
                jnp.ones,
                (1,),
                jnp.float32,
                axes=(None,),
            )

            return amax.value, scale.value

        input_amax, input_scale = generate_a_set(input_name_post_fix)
        weight_amax, weight_scale = generate_a_set(weight_name_post_fix)
        grad_amax, grad_scale = generate_a_set(grad_name_post_fix)

        return FP8MetaPackage(
            input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale
        )
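

# Editor's note: an illustrative sketch (editor-added, hypothetical module) of how a
# subclass requests one FP8 meta set per GEMM inside an @nn.compact method, mirroring
# the pattern used by DenseGeneral and LayerNormMLP below.
class _ExampleFP8Projection(TransformerEngineBase):  # pylint: disable=too-few-public-methods
    features: int = 16

    @nn.compact
    def __call__(self, x):
        kernel = self.param("kernel", nn.initializers.ones, (x.shape[-1], self.features), jnp.float32)
        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            # One meta set (postfix "0") holds the input/weight/grad amax histories and scales.
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
        contracting_dims = ((x.ndim - 1,), (0,))
        return type_safe_dot_general(x, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=contracting_dims)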


class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors have the batch and sequence-length axes
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
        )

        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        y = type_safe_dot_general(
            inputs, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
        )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                jnp.float32,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                jnp.float32,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            y += _apply_low_rank_adaptation(
                inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)

        return y
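

# Editor's note: a minimal usage sketch (editor-added, with assumed shapes) for
# DenseGeneral; with FP8 disabled it behaves like a standard dense layer applied over
# the last axis of the input.
def _example_dense_general_usage():
    x = jnp.ones((8, 128, 512), dtype=jnp.float32)  # (batch, seqlen, hidden)
    layer = DenseGeneral(features=1024, use_bias=True)
    variables = layer.init(jax_random.PRNGKey(0), x)
    return layer.apply(variables, x)  # shape (8, 128, 1024)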


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors have the batch and sequence-length axes
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            assert self.axis == -1  # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
        )

        kernel = jnp.reshape(kernel, kernel_shape)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        if fuse_layernorm:
            z = layernorm_fp8_dot(
                y,
                kernel,
                scale,
                ln_bias,
                fp8_meta_pkg,
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = type_safe_dot_general(
                y, kernel, fp8_meta_pkg=fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (
                *kernel_shape[: len(axis)],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_param_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes(
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
                jnp.float32,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes(
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
                jnp.float32,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            z += _apply_low_rank_adaptation(
                y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
            )

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output  # dense_output, layer_norm_output
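

# Editor's note: a minimal usage sketch (editor-added, with assumed shapes) for
# LayerNormDenseGeneral; note the (seqlen, batch, hidden) layout because
# transpose_batch_sequence defaults to True, and the (output, layernorm_output) tuple.
def _example_layernorm_dense_general_usage():
    x = jnp.ones((128, 8, 512), dtype=jnp.float32)  # (seqlen, batch, hidden)
    layer = LayerNormDenseGeneral(features=1024, return_layernorm_output=True)
    variables = layer.init(jax_random.PRNGKey(0), x)
    out, ln_out = layer.apply(variables, x)
    return out.shape, ln_out.shape  # (128, 8, 1024), (128, 8, 512)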


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard the bias of the first linear transformation
        with a corresponding mesh.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard the bias of the second linear transformation
        with a corresponding mesh.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None is returned as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden (intermediate) dropout.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors have the batch and sequence-length axes
        swapped. If set to True, the input tensors should be in (seqlen, batch, hidden),
        otherwise (batch, seqlen, hidden).
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ("embed",)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ("embed",)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ("embed", "act", "mlp")
    kernel_axes_2: Tuple[str, ...] = ("mlp", "embed")
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ("act", "mlp")
    bias_axes_2: Tuple[str, ...] = ("embed",)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rng_name: str = "dropout"
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma
        )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (
            FP8Helper.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )

        gated_act_pool = [
            ("gelu", "linear"),
            ("silu", "linear"),
            ("relu", "linear"),
            ("quick_gelu", "linear"),
            ("squared_relu", "linear"),
        ]
        act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                # Callable activations cannot be matched against the fused-kernel pools
                # below; keep them as-is so the generic (non-fused) path handles them.
                normalized_acts.append(act)
                continue
            normalized_acts.append(act.lower())
        normalized_acts = tuple(
            reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
        )

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = (
            fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
        )

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1  # Only support axis == -1 at this moment
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(
                self.layernorm_type,
                (features,),
                self.scale_init,
                self.scale_axes,
                self.ln_bias_init,
                self.ln_bias_axes,
                self.dtype,
            )

            if not fuse_layernorm:
                y = layernorm(
                    inputs,
                    scale,
                    ln_bias,
                    layernorm_type=self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                )
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)

        wi_fp8_meta_pkg = None
        wo_fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            wi_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
            wo_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("1")

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes(
            "wi_kernel",
            kernel_1_init,
            num_activations,
            -2,
            kernel_1_each_shape,
            jnp.float32,
            axes=self.kernel_axes_1,
        )
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes(
            "wo_kernel",
            self.kernel_init,
            kernel_2_param_shape,
            jnp.float32,
            axes=self.kernel_axes_2,
        )
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        contract_ind = tuple(range(0, len(axis)))

        ffn1_ckpt_name = "ffn1"
        ffn2_ckpt_name = "ffn2"

        if use_fused_layernorm_mlp:
            assert self.axis == -1  # Only support axis == -1 at this moment

            if self.use_bias:
                bias_1_shape = intermediate_dim
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias", self.bias_init, bias_1_shape, jnp.float32, axes=self.bias_axes_1
                )
                bias_1 = bias_1.astype(self.dtype)

                bias_2_shape = (hidden_size,)
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias", self.bias_init, bias_2_shape, jnp.float32, axes=self.bias_axes_2
                )
                bias_2 = bias_2.astype(self.dtype)
            else:
                bias_1 = None
                bias_2 = None

            out = fused_layernorm_fp8_mlp(
                y,
                scale,
                ln_bias,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                [wi_fp8_meta_pkg, wo_fp8_meta_pkg],
                self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_1_input_axes=self.dot_1_input_axes,
                dot_2_input_axes=self.dot_2_input_axes,
                ffn1_ckpt_name=ffn1_ckpt_name,
                ffn2_ckpt_name=ffn2_ckpt_name,
                activation_type=normalized_acts,
                use_bias=self.use_bias,
            )

        else:  # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_fp8_dot(
                    y,
                    kernel_1,
                    scale,
                    ln_bias,
                    wi_fp8_meta_pkg,
                    self.layernorm_type,
                    zero_centered_gamma=self.zero_centered_gamma,
                    epsilon=self.epsilon,
                    layernorm_input_axes=self.layernorm_input_axes,
                    dot_input_axes=self.dot_1_input_axes,
                )
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = type_safe_dot_general(
                    y, kernel_1, fp8_meta_pkg=wi_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
                )

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_shape = (
                    *kernel_1_shape[: len(axis)],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_shape = (
                    kernel_1_each_shape[0],
                    num_activations,
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_init_each_shape = (
                    kernel_1_each_shape[0],
                    self.low_rank_adaptation_dim,
                )
                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
                wi_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_a_kernel",
                    kernel_1_init,
                    num_activations,
                    -2,
                    wi_lora_a_kernel_init_each_shape,
                    jnp.float32,
                    axes=wi_lora_a_kernel_axes,
                )
                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
                wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)

                wi_lora_b_kernel_shape = (
                    num_activations,
                    self.low_rank_adaptation_dim,
                    self.intermediate_dim,
                )
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wi_lora_b_kernel",
                    nn.initializers.zeros,
                    wi_lora_b_kernel_shape,
                    jnp.float32,
                    axes=wi_lora_b_kernel_axes,
                )
                wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)

                x += _apply_low_rank_adaptation(
                    y,
                    axis,
                    intermediate_dim,
                    wi_lora_a_kernel,
                    wi_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_1 = None
            if self.use_bias:
                bias_1 = nn_partitioning.param_with_axes(
                    "wi_bias", self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1
                )
                bias_1 = bias_1.astype(self.dtype)
                bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation_lu(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                # Remove act axis
                z = jnp.reshape(z, (*z.shape[:-2], -1))

            z = nn.Dropout(
                rate=self.intermediate_dropout_rate,
                broadcast_dims=self.intermediate_hidden_dropout_dims,
                rng_collection=self.intermediate_dropout_rng_name,
            )(z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)

            # DenseGeneral 2
            out = type_safe_dot_general(
                z, kernel_2, fp8_meta_pkg=wo_fp8_meta_pkg, contracting_dims=(axis, contract_ind)
            )

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_a_kernel",
                    self.kernel_init,
                    wo_lora_a_kernel_shape,
                    jnp.float32,
                    axes=wo_lora_a_kernel_axes,
                )
                wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes(
                    "wo_lora_b_kernel",
                    nn.initializers.zeros,
                    wo_lora_b_kernel_shape,
                    jnp.float32,
                    axes=wo_lora_b_kernel_axes,
                )
                wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)

                out += _apply_low_rank_adaptation(
                    z,
                    axis,
                    hidden_size_tuple,
                    wo_lora_a_kernel,
                    wo_lora_b_kernel,
                    self.low_rank_adaptation_alpha,
                )

            bias_2 = None
            if self.use_bias:
                bias_2 = nn_partitioning.param_with_axes(
                    "wo_bias", self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2
                )
                bias_2 = bias_2.astype(self.dtype)
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        return out, ln_output  # output, layer_norm_output
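

# Editor's note: a minimal usage sketch (editor-added, with assumed shapes) for
# LayerNormMLP with a gated activation; inputs use the (seqlen, batch, hidden) layout
# because transpose_batch_sequence defaults to True.
def _example_layernorm_mlp_usage():
    x = jnp.ones((128, 8, 512), dtype=jnp.float32)  # (seqlen, batch, hidden)
    layer = LayerNormMLP(intermediate_dim=2048, activations=("gelu", "linear"))
    variables = layer.init(jax_random.PRNGKey(0), x, deterministic=True)
    out, ln_out = layer.apply(variables, x, deterministic=True)
    return out.shape, ln_out.shape  # (128, 8, 512), (128, 8, 512)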