# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
import warnings
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernorm_geglu_fp8_mlp, geglu
from ..mlp import layernorm_gelu_fp8_mlp, gelu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is None:
        if not zero_centered_gamma:
            return nn.initializers.ones
        return nn.initializers.zeros
    return original_init


def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
                                 bias_axes, dtype):
    scale = nn_partitioning.param_with_axes('scale',
                                            scale_init,
                                            shape,
                                            jnp.float32,
                                            axes=scale_axes)
    scale = jnp.asarray(scale, dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
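    # Only regular LayerNorm has a learnable shift (beta); RMSNorm is scale-only,
    # so 'ln_bias' is created for 'layernorm' and left as None otherwise.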
    if layernorm_type == 'layernorm':
        bias = nn_partitioning.param_with_axes('ln_bias',
                                               bias_init,
                                               shape,
                                               jnp.float32,
                                               axes=bias_axes)
        bias = jnp.asarray(bias, dtype)
    else:
        assert layernorm_type == 'rmsnorm'
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: Array):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(map(lambda x: x.ndim == masks[0].ndim,
                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


class Softmax(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask) * (shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
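
    A minimal usage sketch (the shapes and scale factor below are illustrative
    assumptions, not part of this API):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        # hypothetical attention logits: [batch, heads, q_seqlen, k_seqlen]
        logits = jnp.zeros((2, 8, 128, 128), dtype=jnp.float32)
        module = Softmax(scale_factor=0.125)
        variables = module.init(jax.random.PRNGKey(0), logits)
        probs = module.apply(variables, logits)    # same shape as logits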
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs
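        # Prefer the fused softmax kernel when it is available for this shape/dtype;
        # otherwise fall back to composing the mask/bias manually below.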

        if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
                self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(mask > 0,
                                            jnp.full(mask.shape, -1e10).astype(dtype),
                                            jnp.full(mask.shape, 0.).astype(dtype))

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # If self.softmax_type is SCALED_UPPER_TRIANG_MASKED but its kernel is
            # unavailable, fall back to the pure scaled softmax custom call.
            if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
                                           dtype):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs


class LayerNorm(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    the size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is a learnable affine transform parameter of
    the size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
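
    A minimal usage sketch (shapes are illustrative assumptions, not part of this API):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        # hypothetical input: [batch, seqlen, hidden]
        x = jnp.ones((4, 128, 512), dtype=jnp.float32)
        layer = LayerNorm(layernorm_type='rmsnorm')
        params = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(params, x)    # normalized output, same shape as x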
    """
    epsilon: float = 1e-6
    layernorm_type: str = 'layernorm'
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ('embed',)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type = None

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        warnings.warn("sharding_type of LayerNorm would be removed in the near feature",
                      DeprecationWarning)

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                      self.scale_init, self.scale_axes,
                                                      self.bias_init, self.bias_axes, self.dtype)
        return layernorm(x,
                         scale,
                         ln_bias,
                         layernorm_type=self.layernorm_type,
                         zero_centered_gamma=self.zero_centered_gamma,
                         epsilon=self.epsilon)


class TransformerEngineBase(nn.Module):
    """
    Base class for Transformer Engine modules.
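
    The FP8 helpers below are intended to be called from subclasses. A minimal
    sketch of the pattern used by the modules in this file (the GEMM count of 1
    is illustrative):

    .. code-block:: python

        # Inside a subclass's __call__: fetch one set of FP8 metas per GEMM only
        # when FP8 is enabled, otherwise pass None to the dot routines.
        fp8_meta_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_pkg = TransformerEngineBase.get_fp8_meta_package(1)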
    """

    @staticmethod
    def get_fp8_metas(num_of_gemm: int) -> List[jnp.ndarray]:
        """
        Get the FP8 metas
        """
        num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        axes = ('fp8_meta_axis', 'fp8_meta_history')

        fp8_max = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                     FP8Helper.FP8_MAX_NAME,
                                                     FP8Helper.generate_fp8_max_array,
                                                     num_of_meta,
                                                     axes=axes)
        fp8_metas_amax = nn_partitioning.variable_with_axes(
            FP8Helper.FP8_COLLECTION_NAME,
            FP8Helper.FP8_AMAX_NAME,
            jnp.zeros, (num_of_meta, FP8Helper.AMAX_HISTORY_LEN),
            jnp.float32,
            axes=axes)
        fp8_metas_scale = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                             FP8Helper.FP8_SCALE_NAME,
                                                             jnp.ones, (num_of_meta, 1),
                                                             jnp.float32,
                                                             axes=axes)
        fp8_metas_scale_inv = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                                 FP8Helper.FP8_SCALE_INV_NAME,
                                                                 jnp.ones, (num_of_meta, 1),
                                                                 jnp.float32,
                                                                 axes=axes)

        return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value

    @staticmethod
    def get_fp8_meta_package(num_of_gemm: int) -> FP8MetaPackage:
        """
        Get the FP8 metas packed into an FP8MetaPackage
        """
        fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
            TransformerEngineBase.get_fp8_metas(num_of_gemm)

        return FP8MetaPackage(num_of_gemm, fp8_max, fp8_metas_amax, fp8_metas_scale,
                              fp8_metas_scale_inv)


class DenseGeneral(TransformerEngineBase):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
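
    A minimal usage sketch (shapes and axis names are illustrative assumptions,
    not part of this API):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        # hypothetical input: [batch, seqlen, hidden]
        x = jnp.ones((4, 128, 512), dtype=jnp.float32)
        layer = DenseGeneral(features=1024, kernel_axes=('embed', 'mlp'))
        params = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(params, x)    # [batch, seqlen, 1024]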
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        warnings.warn("sharding_type of DenseGeneral would be removed in the near feature",
                      DeprecationWarning)

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
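        # The parameter is stored with its contracting dimensions flattened (simpler
        # sharding metadata) and reshaped to the full kernel shape before the dot.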
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_gemm_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_gemm_pkg = TransformerEngineBase.get_fp8_meta_package(1)

        y = type_safe_dot_general(inputs,
                                  kernel,
                                  fp8_meta_pkg=fp8_gemm_pkg,
                                  contracting_dims=(axis, contract_ind))

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None will be returned as the second tensor in outputs.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
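
    A minimal usage sketch (shapes are illustrative assumptions, not part of this API):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        # hypothetical input: [batch, seqlen, hidden]
        x = jnp.ones((4, 128, 512), dtype=jnp.float32)
        layer = LayerNormDenseGeneral(features=1024,
                                      return_layernorm_output=True,
                                      transpose_batch_sequence=False)
        params = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(params, x)    # y: [batch, seqlen, 1024]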
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None
    sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        warnings.warn("sharding_type of LayerNormDenseGeneral would be removed in the near feature",
                      DeprecationWarning)

        ln_output = None

        fuse_layernorm = (FP8Helper.is_fp8_enabled() and not self.return_layernorm_output
                          and self.enable_layernorm)

        if self.enable_layernorm:
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            assert self.axis == -1    # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_package = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_package = TransformerEngineBase.get_fp8_meta_package(1)

        if fuse_layernorm:
            z = layernorm_fp8_dot(y,
                                  kernel,
                                  scale,
                                  ln_bias,
                                  fp8_meta_package,
                                  self.layernorm_type,
                                  zero_centered_gamma=self.zero_centered_gamma,
                                  epsilon=self.epsilon,
                                  layernorm_input_axes=self.layernorm_input_axes,
                                  dot_input_axes=self.dot_input_axes)
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = type_safe_dot_general(y,
                                      kernel,
                                      fp8_meta_pkg=fp8_meta_package,
                                      contracting_dims=(axis, contract_ind))

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output    # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first linear transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the second linear transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, None will be returned as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the given RNGs via flax.linen.Module.apply that is used for
        generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the intermediate (hidden) dropout.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
        sharding constraint.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
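
    A minimal usage sketch (shapes and the GEGLU-style activation pair are
    illustrative assumptions, not part of this API):

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        # hypothetical input: [batch, seqlen, hidden]
        x = jnp.ones((4, 128, 512), dtype=jnp.float32)
        mlp = LayerNormMLP(intermediate_dim=2048,
                           activations=('gelu', 'linear'),
                           transpose_batch_sequence=False)
        params = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
        y, ln_out = mlp.apply(params, x, deterministic=True)    # y: [batch, seqlen, 512]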
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp')
    kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ('act', 'mlp')
    bias_axes_2: Tuple[str, ...] = ('embed',)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ('relu',)
    intermediate_dropout_rng_name: str = 'dropout'
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    layernorm_input_axes: Tuple[str, ...] = None
    dot_1_input_axes: Tuple[str, ...] = None
    dot_2_input_axes: Tuple[str, ...] = None
    major_sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        warnings.warn("major_sharding_type of LayerNormMLP would be removed in the near feature",
                      DeprecationWarning)

        ln_output = None

        fuse_layernorm = (FP8Helper.is_fp8_enabled() and not self.return_layernorm_output
                          and self.enable_layernorm)

        def is_geglu(acts):
            geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]

            normalize_acts = []
            for act in acts:
                if not isinstance(act, str):
                    return False
                normalize_acts.append(act.lower())
            return tuple(normalize_acts) in geglu_act_pool

        def is_gelu(acts):
            gelu_act_pool = [('gelu',)]

            normalize_acts = []
            for act in acts:
                if not isinstance(act, str):
                    return False
                normalize_acts.append(act.lower())
            return tuple(normalize_acts) in gelu_act_pool

        use_fused_ln_geglu_mlp = (fuse_layernorm and (not self.use_bias)
                                  and is_geglu(self.activations)
                                  and (self.intermediate_dropout_rate < 1e-3))

        use_fused_ln_gelu_mlp = (fuse_layernorm and self.use_bias
                                 and is_gelu(self.activations)
                                 and (self.intermediate_dropout_rate < 1e-3))

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1    # Only support axis == -1 at this moment
            inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
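            # Initialize one kernel per activation branch with independent RNG keys and
            # stack them along `stack_axis`, so each activation (e.g. a GEGLU pair) gets
            # its own projection weights inside a single parameter.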
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)

        num_of_gemm = 2
        fp8_meta_package = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_package = TransformerEngineBase.get_fp8_meta_package(num_of_gemm)

        num_activations = len(self.activations)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
                                                   kernel_1_init,
                                                   num_activations,
                                                   -2,
                                                   kernel_1_each_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_1)
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
                                                   self.kernel_init,
                                                   kernel_2_param_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_2)
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        contract_ind = tuple(range(0, len(axis)))

        ffn1_ckpt_name = 'ffn1'
        ffn2_ckpt_name = 'ffn2'
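        # Tag the outputs of the two GEMMs so activation-checkpointing policies
        # (jax.ad_checkpoint.checkpoint_name) can refer to them by name.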

        if use_fused_ln_geglu_mlp:
            assert self.axis == -1    # Only support axis == -1 at this moment

            out = layernorm_geglu_fp8_mlp(y,
                                          scale,
                                          ln_bias, [kernel_1, kernel_2],
                                          fp8_meta_package,
                                          self.layernorm_type,
                                          zero_centered_gamma=self.zero_centered_gamma,
                                          epsilon=self.epsilon,
                                          layernorm_input_axes=self.layernorm_input_axes,
                                          dot_1_input_axes=self.dot_1_input_axes,
                                          dot_2_input_axes=self.dot_2_input_axes,
                                          ffn1_ckpt_name=ffn1_ckpt_name,
                                          ffn2_ckpt_name=ffn2_ckpt_name)
        elif use_fused_ln_gelu_mlp:
            assert self.axis == -1    # Only support axis == -1 at this moment

            bias_1 = nn_partitioning.param_with_axes('wi_bias',
                                                     self.bias_init,
                                                     intermediate_dim,
                                                     jnp.float32,
                                                     axes=self.bias_axes_1)
            bias_1 = bias_1.astype(self.dtype)

            bias_2 = nn_partitioning.param_with_axes('wo_bias',
                                                     self.bias_init, (hidden_size,),
                                                     jnp.float32,
                                                     axes=self.bias_axes_2)
            bias_2 = bias_2.astype(self.dtype)

            out = layernorm_gelu_fp8_mlp(y,
                                         scale,
                                         ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
                                         fp8_meta_package,
                                         self.layernorm_type,
                                         zero_centered_gamma=self.zero_centered_gamma,
                                         epsilon=self.epsilon,
                                         layernorm_input_axes=self.layernorm_input_axes,
                                         dot_1_input_axes=self.dot_1_input_axes,
                                         dot_2_input_axes=self.dot_2_input_axes,
                                         ffn1_ckpt_name=ffn1_ckpt_name,
                                         ffn2_ckpt_name=ffn2_ckpt_name)
        else:    # neither fused GEGLU nor fused GELU MLP path

            # DenseGeneral 1
            gemm1_fp8_meta_package = None if fp8_meta_package is None \
                                     else fp8_meta_package.get_package_by_gemm_idx(0)
            if fuse_layernorm:
                x = layernorm_fp8_dot(y,
                                      kernel_1,
                                      scale,
                                      ln_bias,
                                      gemm1_fp8_meta_package,
                                      self.layernorm_type,
                                      zero_centered_gamma=self.zero_centered_gamma,
                                      epsilon=self.epsilon,
                                      layernorm_input_axes=self.layernorm_input_axes,
                                      dot_input_axes=self.dot_1_input_axes)
            else:
                y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
                x = type_safe_dot_general(y,
                                          kernel_1,
                                          fp8_meta_pkg=gemm1_fp8_meta_package,
                                          contracting_dims=(axis, contract_ind))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wi_bias',
                                                       self.bias_init,
                                                       intermediate_dim,
                                                       jnp.float32,
                                                       axes=self.bias_axes_1)
                bias = bias.astype(self.dtype)
                bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
                x += jnp.reshape(bias, bias_shape)

            x = checkpoint_name(x, ffn1_ckpt_name)

            activations = []
            if is_geglu(self.activations):
                z = geglu(x)
            elif is_gelu(self.activations):
                z = gelu(x)
                z = jnp.reshape(z, (*z.shape[:-2], -1))
            else:
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(self.activations):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                z = jnp.reshape(z, (*z.shape[:-2], -1))

            z = nn.Dropout(rate=self.intermediate_dropout_rate,
                           broadcast_dims=self.intermediate_hidden_dropout_dims,
                           rng_collection=self.intermediate_dropout_rng_name)(
                               z, deterministic=deterministic)

            z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)

            # DenseGeneral 2
            gemm2_fp8_meta_package = None if fp8_meta_package is None \
                                     else fp8_meta_package.get_package_by_gemm_idx(1)

            out = type_safe_dot_general(z,
                                        kernel_2,
                                        fp8_meta_pkg=gemm2_fp8_meta_package,
                                        contracting_dims=(axis, contract_ind))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wo_bias',
                                                       self.bias_init, (hidden_size,),
                                                       jnp.float32,
                                                       axes=self.bias_axes_2)
                bias = bias.astype(self.dtype)
                out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, ffn2_ckpt_name)

        return out, ln_output    # Output, layer_norm_output