# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..layernorm_mlp import fused_layernorm_fp8_mlp, activation_lu
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is not None:
        return original_init

    if not zero_centered_gamma:
        return nn.initializers.ones

    return nn.initializers.zeros


def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
                                 bias_axes, dtype):
    scale = nn_partitioning.param_with_axes('scale',
                                            scale_init,
                                            shape,
                                            jnp.float32,
                                            axes=scale_axes)
    scale = jnp.asarray(scale, dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'layernorm':
        bias = nn_partitioning.param_with_axes('ln_bias',
                                               bias_init,
                                               shape,
                                               jnp.float32,
                                               axes=bias_axes)
        bias = jnp.asarray(bias, dtype)
    else:
        assert layernorm_type == 'rmsnorm'
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(map(lambda x: x.ndim == masks[0].ndim,
                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""

    assert len(axis) <= 5
    hidden_in_names = 'ijklm'[:len(axis)]
    assert len(features) <= 5
    hidden_out_names = 'nopqr'[:len(features)]
    rank_name = 's'

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \
                           f"->{output_einsum_express}"

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output

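# A minimal sketch of the contraction above (shapes are illustrative only, not part of the
# public API): with a single contracting axis and a single output feature axis the einsum
# reduces to "...i,is,sn->...n", i.e. x @ lora_a @ lora_b scaled by alpha / rank.
#
#     x = jnp.ones((2, 8, 16))            # (batch, seq, hidden_in)
#     lora_a = jnp.ones((16, 4))          # (hidden_in, rank)
#     lora_b = jnp.ones((4, 32))          # (rank, hidden_out)
#     y = _apply_low_rank_adaptation(x, axis=(2,), features=(32,),
#                                    lora_a_kernel=lora_a, lora_b_kernel=lora_b, alpha=8.0)
#     # y.shape == (2, 8, 32); scaling factor == 8.0 / 4 == 2.0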

class Softmax(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.

    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs

        if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
                self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(mask > 0,
                                            jnp.full(mask.shape, -1e10).astype(dtype),
                                            jnp.full(mask.shape, 0.).astype(dtype))

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # If the requested softmax kernel (e.g. SCALED_UPPER_TRIANG_MASKED) is
            # unavailable, fall back to the pure scaled softmax custom call when possible.
            if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
                                           dtype):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs

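# A minimal usage sketch (shapes and scale factor are illustrative only). The module holds no
# parameters, so an empty variable dict is enough for `apply`:
#
#     logits = jnp.zeros((4, 16, 128, 128))    # [batch, heads, q_seqlen, k_seqlen]
#     probs = Softmax(scale_factor=0.125).apply({}, logits)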

class LayerNorm(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
228
        A value added to the denominator of layer normalization for numerical stability.
229
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
230
        Indicate the type of layer normalization.
231
232
233
234
235
236
237
238
239
240
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
241
        Used for initializing scale factors :math:`\gamma`.
242
243
244
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
245
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
246
    scale_axes : Tuple[str, ...], default = ('embed', )
247
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
248
    bias_init : Initializer, default = flax.linen.initializers.zeros
249
250
251
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
252
253
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
254
        only used when :attr:`layernorm_type='layernorm'`.
255
256
257
258
259

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        the data type used to allocate the initial parameters.
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
260
    transpose_batch_sequence : bool, default = False
261
262
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
263
264
265
266
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """
    epsilon: float = 1e-6
    layernorm_type: str = 'layernorm'
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ('embed',)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                      self.scale_init, self.scale_axes,
                                                      self.bias_init, self.bias_axes, self.dtype)
        return layernorm(x,
                         scale,
                         ln_bias,
                         layernorm_type=self.layernorm_type,
                         zero_centered_gamma=self.zero_centered_gamma,
                         epsilon=self.epsilon)

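# A minimal usage sketch (shapes are illustrative only):
#
#     layer = LayerNorm(layernorm_type='rmsnorm', dtype=jnp.bfloat16)
#     variables = layer.init(jax_random.PRNGKey(0), jnp.ones((8, 128, 512)))
#     y = layer.apply(variables, jnp.ones((8, 128, 512)))    # same shape as the input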

class TransformerEngineBase(nn.Module):    # pylint: disable=too-few-public-methods
    """
    Base class of transformer engine
    """

    @staticmethod
    def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage:
        """
        Generate a set of FP8 meta for a GEMM.
        """

        input_name_post_fix = f"_i_{postfix}"
        weight_name_post_fix = f"_w_{postfix}"
        grad_name_post_fix = f"_g_{postfix}"

        def generate_a_set(target_postfix):
            amax = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                      f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}",
                                                      jnp.zeros, (FP8Helper.AMAX_HISTORY_LEN,),
                                                      jnp.float32,
                                                      axes=(None,))

            scale = nn_partitioning.variable_with_axes(
                FP8Helper.FP8_COLLECTION_NAME,
                f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}",
                jnp.ones, (1,),
                jnp.float32,
                axes=(None,))

            return amax.value, scale.value

        input_amax, input_scale = generate_a_set(input_name_post_fix)
        weight_amax, weight_scale = generate_a_set(weight_name_post_fix)
        grad_amax, grad_scale = generate_a_set(grad_name_post_fix)

        return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax,
                              grad_scale)

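# Note on generate_fp8_meta_set above: for a GEMM tagged with postfix "0" it creates one
# (amax-history, scale) variable pair per tensor role in the FP8Helper.FP8_COLLECTION_NAME
# collection, using the suffixes "_i_0" (input), "_w_0" (weight) and "_g_0" (gradient)
# appended to FP8Helper.FP8_AMAX_NAME / FP8Helper.FP8_SCALE_NAME.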

class DenseGeneral(TransformerEngineBase):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        y = type_safe_dot_general(inputs,
                                  kernel,
                                  fp8_meta_pkg=fp8_meta_pkg,
                                  contracting_dims=(axis, contract_ind))

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
                                   self.low_rank_adaptation_dim)
            lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
                                        self.low_rank_adaptation_dim)
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
                                                            self.kernel_init,
                                                            lora_a_kernel_init_shape,
                                                            jnp.float32,
                                                            axes=lora_a_kernel_axes)
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
                                                            nn.initializers.zeros,
                                                            lora_b_kernel_shape,
                                                            jnp.float32,
                                                            axes=lora_b_kernel_axes)
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel,
                                            self.low_rank_adaptation_alpha)

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)
        return y

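# A minimal usage sketch (shapes and sharding axis names are illustrative only); FP8 metadata
# is only created when this runs inside an fp8_autocast-style context where
# FP8Helper.is_fp8_enabled() returns True:
#
#     layer = DenseGeneral(features=1024, kernel_axes=('embed', 'mlp'),
#                          use_bias=True, bias_axes=('mlp',))
#     variables = layer.init(jax_random.PRNGKey(0), jnp.ones((8, 128, 512)))
#     y = layer.apply(variables, jnp.ones((8, 128, 512)))    # y.shape == (8, 128, 1024)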

class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (FP8Helper.is_fp8_enabled() and not self.return_layernorm_output
                          and self.enable_layernorm)

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            assert self.axis == -1    # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")

        if fuse_layernorm:
            z = layernorm_fp8_dot(y,
                                  kernel,
                                  scale,
                                  ln_bias,
                                  fp8_meta_pkg,
                                  self.layernorm_type,
                                  zero_centered_gamma=self.zero_centered_gamma,
                                  epsilon=self.epsilon,
                                  layernorm_input_axes=self.layernorm_input_axes,
                                  dot_input_axes=self.dot_input_axes)
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = type_safe_dot_general(y,
                                      kernel,
                                      fp8_meta_pkg=fp8_meta_pkg,
                                      contracting_dims=(axis, contract_ind))

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
                                   self.low_rank_adaptation_dim)
            lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
                                        self.low_rank_adaptation_dim)
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
                                                            self.kernel_init,
                                                            lora_a_kernel_init_shape,
                                                            jnp.float32,
                                                            axes=lora_a_kernel_axes)
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
                                                            nn.initializers.zeros,
                                                            lora_b_kernel_shape,
                                                            jnp.float32,
                                                            axes=lora_b_kernel_axes)
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel,
                                            self.low_rank_adaptation_alpha)

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output    # dense_output, layer_norm_output

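# A minimal usage sketch (shapes are illustrative only). The module returns a pair:
# the dense output and, when return_layernorm_output=True, the normalized input:
#
#     layer = LayerNormDenseGeneral(features=1024, transpose_batch_sequence=False,
#                                   kernel_axes=('embed', 'mlp'))
#     variables = layer.init(jax_random.PRNGKey(0), jnp.ones((8, 128, 512)))
#     y, ln_out = layer.apply(variables, jnp.ones((8, 128, 512)))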

class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard the bias of the first linear transformation
        with a corresponding mesh. Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard the bias of the second linear transformation
        with a corresponding mesh. Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the dropout op
        after the :attr:`activations`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp')
    kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ('act', 'mlp')
    bias_axes_2: Tuple[str, ...] = ('embed',)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ('relu',)
    intermediate_dropout_rng_name: str = 'dropout'
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """

        ln_output = None

        fuse_layernorm = (FP8Helper.is_fp8_enabled() and not self.return_layernorm_output
                          and self.enable_layernorm)

        gated_act_pool = [('gelu', 'linear'), ('silu', 'linear'), ('relu', 'linear'),
                          ('quick_gelu', 'linear'), ('squared_relu', 'linear')]
        act_pool = [('gelu',), ('silu',), ('relu',), ('quick_gelu',), ('squared_relu',)]
        normalized_acts = []
        for act in self.activations:
            if not isinstance(act, str):
                # Callable activations cannot be matched against the fused-kernel pools,
                # so fall back to the generic (unfused) path below.
                normalized_acts = tuple(self.activations)
                break
            normalized_acts.append(act.lower())
        else:
            normalized_acts = tuple(
                reversed(normalized_acts) if normalized_acts[0] == 'linear' else normalized_acts)

        is_act_implemented = normalized_acts in (gated_act_pool + act_pool)

        use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and \
                                  self.intermediate_dropout_rate < 1e-3

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1    # Only support axis == -1 at this moment
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)

        wi_fp8_meta_pkg = None
        wo_fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            wi_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
            wo_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("1")

        num_activations = len(normalized_acts)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
                                                   kernel_1_init,
                                                   num_activations,
                                                   -2,
                                                   kernel_1_each_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_1)
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
                                                   self.kernel_init,
                                                   kernel_2_param_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_2)
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        contract_ind = tuple(range(0, len(axis)))

        ffn1_ckpt_name = 'ffn1'
        ffn2_ckpt_name = 'ffn2'

        if use_fused_layernorm_mlp:
            assert self.axis == -1    # Only support axis == -1 at this moment

            if self.use_bias:
                bias_1_shape = intermediate_dim
                bias_1 = nn_partitioning.param_with_axes('wi_bias',
                                                         self.bias_init,
                                                         bias_1_shape,
                                                         jnp.float32,
                                                         axes=self.bias_axes_1)
                bias_1 = bias_1.astype(self.dtype)

                bias_2_shape = (hidden_size,)
                bias_2 = nn_partitioning.param_with_axes('wo_bias',
                                                         self.bias_init,
                                                         bias_2_shape,
                                                         jnp.float32,
                                                         axes=self.bias_axes_2)
                bias_2 = bias_2.astype(self.dtype)
            else:
                bias_1 = None
                bias_2 = None

            out = fused_layernorm_fp8_mlp(y,
                                          scale,
                                          ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
                                          [wi_fp8_meta_pkg, wo_fp8_meta_pkg],
                                          self.layernorm_type,
                                          zero_centered_gamma=self.zero_centered_gamma,
                                          epsilon=self.epsilon,
                                          layernorm_input_axes=self.layernorm_input_axes,
                                          dot_1_input_axes=self.dot_1_input_axes,
                                          dot_2_input_axes=self.dot_2_input_axes,
                                          ffn1_ckpt_name=ffn1_ckpt_name,
                                          ffn2_ckpt_name=ffn2_ckpt_name,
                                          activation_type=normalized_acts,
                                          use_bias=self.use_bias)

        else:    # not use_fused_layernorm_mlp
            # DenseGeneral 1
            if fuse_layernorm:
                x = layernorm_fp8_dot(y,
                                      kernel_1,
                                      scale,
                                      ln_bias,
                                      wi_fp8_meta_pkg,
                                      self.layernorm_type,
                                      zero_centered_gamma=self.zero_centered_gamma,
                                      epsilon=self.epsilon,
                                      layernorm_input_axes=self.layernorm_input_axes,
                                      dot_input_axes=self.dot_1_input_axes)
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = type_safe_dot_general(y,
                                          kernel_1,
                                          fp8_meta_pkg=wi_fp8_meta_pkg,
                                          contracting_dims=(axis, contract_ind))

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations,
                                          self.low_rank_adaptation_dim)
                wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations,
                                               self.low_rank_adaptation_dim)
                wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0],
                                                    self.low_rank_adaptation_dim)
                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
                wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel',
                                                                   kernel_1_init,
                                                                   num_activations,
                                                                   -2,
                                                                   wi_lora_a_kernel_init_each_shape,
                                                                   jnp.float32,
                                                                   axes=wi_lora_a_kernel_axes)
                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
                wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)

                wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim,
                                          self.intermediate_dim)
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel',
                                                                   nn.initializers.zeros,
                                                                   wi_lora_b_kernel_shape,
                                                                   jnp.float32,
                                                                   axes=wi_lora_b_kernel_axes)
                wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)

                x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel,
                                                wi_lora_b_kernel, self.low_rank_adaptation_alpha)

            bias_1 = None
            if self.use_bias:
                bias_1 = nn_partitioning.param_with_axes('wi_bias',
                                                         self.bias_init,
                                                         intermediate_dim,
                                                         jnp.float32,
                                                         axes=self.bias_axes_1)
                bias_1 = bias_1.astype(self.dtype)
                bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
                x += jnp.reshape(bias_1, bias_1_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)
            if is_act_implemented:
                z = activation_lu(x, normalized_acts)
            else:
                activations = []
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(normalized_acts):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                # Remove act axis
                z = jnp.reshape(z, (*z.shape[:-2], -1))

            z = nn.Dropout(rate=self.intermediate_dropout_rate,
                           broadcast_dims=self.intermediate_hidden_dropout_dims,
                           rng_collection=self.intermediate_dropout_rng_name)(
                               z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)

            # DenseGeneral 2
            out = type_safe_dot_general(z,
                                        kernel_2,
                                        fp8_meta_pkg=wo_fp8_meta_pkg,
                                        contracting_dims=(axis, contract_ind))

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel',
                                                                   self.kernel_init,
                                                                   wo_lora_a_kernel_shape,
                                                                   jnp.float32,
                                                                   axes=wo_lora_a_kernel_axes)
                wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel',
                                                                   nn.initializers.zeros,
                                                                   wo_lora_b_kernel_shape,
                                                                   jnp.float32,
                                                                   axes=wo_lora_b_kernel_axes)
                wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)

                out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel,
                                                  wo_lora_b_kernel, self.low_rank_adaptation_alpha)

            bias_2 = None
            if self.use_bias:
                bias_2 = nn_partitioning.param_with_axes('wo_bias',
                                                         self.bias_init, (hidden_size,),
                                                         jnp.float32,
                                                         axes=self.bias_axes_2)
                bias_2 = bias_2.astype(self.dtype)
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        return out, ln_output    # Output, layer_norm_output
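# A minimal usage sketch (shapes and activation choice are illustrative only). Outside an FP8
# context this takes the unfused path: layernorm, first GEMM, gated activation, dropout,
# second GEMM, returning (output, layernorm_output):
#
#     mlp = LayerNormMLP(intermediate_dim=2048, activations=('gelu', 'linear'),
#                        transpose_batch_sequence=False)
#     variables = mlp.init(jax_random.PRNGKey(0), jnp.ones((8, 128, 512)), deterministic=True)
#     y, ln_out = mlp.apply(variables, jnp.ones((8, 128, 512)), deterministic=True)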