# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
import warnings
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernrom_geglu_fp8_mlp, geglu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)
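

# --- Illustrative sketch (added as documentation; not part of the library API) ---
# A minimal example of how the two axis helpers above behave; the values are
# arbitrary. Wrapped in a private, uncalled function so nothing runs at import time.
def _example_axis_helpers():
    # _canonicalize_tuple wraps a scalar into a 1-tuple and passes iterables through.
    assert _canonicalize_tuple(-1) == (-1,)
    assert _canonicalize_tuple((0, 1)) == (0, 1)
    # _normalize_axes resolves negative axis indices against the tensor rank.
    assert _normalize_axes((-1,), 3) == (2,)
    assert _normalize_axes((0, -2), 4) == (0, 2)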


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is None:
        if not zero_centered_gamma:
            return nn.initializers.ones
        return nn.initializers.zeros
    return original_init


def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
                                 bias_axes, dtype):
    scale = nn_partitioning.param_with_axes('scale',
                                            scale_init,
                                            shape,
                                            jnp.float32,
                                            axes=scale_axes)
    scale = jnp.asarray(scale, dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'layernorm':
        bias = nn_partitioning.param_with_axes('ln_bias',
                                               bias_init,
                                               shape,
                                               jnp.float32,
                                               axes=bias_axes)
        bias = jnp.asarray(bias, dtype)
    else:
        assert layernorm_type == 'rmsnorm'
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(map(lambda x: x.ndim == masks[0].ndim,
                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask
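

# --- Illustrative sketch (added as documentation; not part of the library API) ---
# A minimal example of the activation and bias helpers above, assuming the standard
# flax.linen activation names; wrapped in a private, uncalled function so nothing
# runs at import time.
def _example_activation_and_bias_helpers():
    # 'linear' maps to the identity; any other string is looked up on flax.linen.
    identity = _convert_to_activation_function('linear')
    gelu_fn = _convert_to_activation_function('gelu')
    x = jnp.ones((2, 4))
    assert jnp.array_equal(identity(x), x)
    _ = gelu_fn(x)    # equivalent to flax.linen.gelu(x)

    # _combine_biases adds all non-None biases together (their ranks must match).
    b1 = jnp.zeros((1, 1, 4, 4))
    b2 = jnp.ones((1, 1, 4, 4))
    assert jnp.array_equal(_combine_biases(b1, None, b2), b1 + b2)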


class Softmax(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs

        if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
                self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(mask > 0,
                                            jnp.full(mask.shape, -1e10).astype(dtype),
                                            jnp.full(mask.shape, 0.).astype(dtype))

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # If self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED and its
            # kernel is unavailable, fall back to the pure scaled softmax custom call.
            if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
                                           dtype):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs
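

# --- Illustrative usage sketch (added as documentation; not part of the library API) ---
# Shows how the Softmax module above is typically applied to attention logits of
# shape [batch, heads, q_seqlen, k_seqlen]. The module holds no parameters, so an
# empty variable collection is passed to `apply`. Shapes/values are arbitrary examples.
def _example_softmax_usage():
    key = jax_random.PRNGKey(0)
    logits = jax_random.normal(key, (2, 8, 16, 16), dtype=jnp.float32)
    softmax_module = Softmax(scale_factor=0.125,    # e.g. 1/sqrt(head_dim)
                             softmax_type=SoftmaxType.SCALED)
    probs = softmax_module.apply({}, logits)
    assert probs.shape == logits.shape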


class LayerNorm(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is learnable affine transform parameters of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """
    epsilon: float = 1e-6
    layernorm_type: str = 'layernorm'
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ('embed',)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type = None

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`x`.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        warnings.warn("sharding_type of LayerNorm will be removed in the near future",
                      DeprecationWarning)

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                      self.scale_init, self.scale_axes,
                                                      self.bias_init, self.bias_axes, self.dtype)
        return layernorm(x,
                         scale,
                         ln_bias,
                         layernorm_type=self.layernorm_type,
                         zero_centered_gamma=self.zero_centered_gamma,
                         epsilon=self.epsilon)
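

# --- Illustrative usage sketch (added as documentation; not part of the library API) ---
# A minimal example of initializing and applying the LayerNorm module above, assuming
# FP8 is not enabled. The shapes and the RMSNorm choice are arbitrary.
def _example_layernorm_usage():
    key = jax_random.PRNGKey(0)
    x = jax_random.normal(key, (4, 16, 32), dtype=jnp.float32)    # (batch, seqlen, hidden)
    ln = LayerNorm(layernorm_type='rmsnorm', epsilon=1e-6)
    variables = ln.init(key, x)    # creates the 'scale' parameter (no bias for RMSNorm)
    y = ln.apply(variables, x)
    assert y.shape == x.shape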


class TransformerEngineBase(nn.Module):
    """
    Base class of transformer engine
    """

    @staticmethod
    def get_fp8_metas(num_of_gemm: int) -> List[jnp.ndarray]:
        """
        Get the FP8 metas
        """
        num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        axes = ('fp8_meta_axis', 'fp8_meta_history')

        fp8_max = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                     FP8Helper.FP8_MAX_NAME,
                                                     FP8Helper.generate_fp8_max_array,
                                                     num_of_meta,
                                                     axes=axes)
        fp8_metas_amax = nn_partitioning.variable_with_axes(
            FP8Helper.FP8_COLLECTION_NAME,
            FP8Helper.FP8_AMAX_NAME,
            jnp.zeros, (num_of_meta, FP8Helper.AMAX_HISTORY_LEN),
            jnp.float32,
            axes=axes)
        fp8_metas_scale = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                             FP8Helper.FP8_SCALE_NAME,
                                                             jnp.ones, (num_of_meta, 1),
                                                             jnp.float32,
                                                             axes=axes)
        fp8_metas_scale_inv = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                                 FP8Helper.FP8_SCALE_INV_NAME,
                                                                 jnp.ones, (num_of_meta, 1),
                                                                 jnp.float32,
                                                                 axes=axes)

        return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value

    @staticmethod
    def get_fp8_meta_package(num_of_gemm: int) -> FP8MetaPackage:
        """
        Get the FP8 meta package.
        """
        fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
            TransformerEngineBase.get_fp8_metas(num_of_gemm)

        return FP8MetaPackage(num_of_gemm, fp8_max, fp8_metas_amax, fp8_metas_scale,
                              fp8_metas_scale_inv)


class DenseGeneral(TransformerEngineBase):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        warnings.warn("sharding_type of DenseGeneral will be removed in the near future",
                      DeprecationWarning)

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_gemm_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_gemm_pkg = \
                    TransformerEngineBase.get_fp8_meta_package(1)

        y = type_safe_dot_general(inputs,
                                  kernel,
                                  fp8_meta_pkg=fp8_gemm_pkg,
                                  contracting_dims=(axis, contract_ind))

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)
        return y
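

# --- Illustrative usage sketch (added as documentation; not part of the library API) ---
# A minimal example of the DenseGeneral module above, assuming FP8 is disabled so the
# plain (non-FP8) dot path is taken. Shapes and axis names are arbitrary.
def _example_dense_general_usage():
    key = jax_random.PRNGKey(0)
    x = jax_random.normal(key, (4, 16, 32), dtype=jnp.float32)    # (batch, seqlen, hidden)
    dense = DenseGeneral(features=64, kernel_axes=('embed', 'mlp'), use_bias=True)
    variables = dense.init(key, x)    # 'kernel' of shape (32, 64), 'bias' of shape (64,)
    y = dense.apply(variables, x)
    assert y.shape == (4, 16, 64)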


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    depth_scaling: float = None
    sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        warnings.warn("sharding_type of LayerNormDenseGeneral will be removed in the near future",
                      DeprecationWarning)

        ln_output = None

        fuse_layernorm = FP8Helper.is_fp8_enabled(
        ) and not self.return_layernorm_output and self.enable_layernorm

        if self.enable_layernorm:
            assert self.axis == -1    # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_package = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_package = \
                    TransformerEngineBase.get_fp8_meta_package(1)

        if fuse_layernorm:
            z = layernorm_fp8_dot(y,
                                  kernel,
                                  scale,
                                  ln_bias,
                                  fp8_meta_package,
                                  self.layernorm_type,
                                  zero_centered_gamma=self.zero_centered_gamma,
                                  epsilon=self.epsilon)
        else:
            z = type_safe_dot_general(y,
                                      kernel,
                                      fp8_meta_pkg=fp8_meta_package,
                                      contracting_dims=(axis, contract_ind))

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output    # dense_output, layer_norm_output
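

# --- Illustrative usage sketch (added as documentation; not part of the library API) ---
# A minimal example of the LayerNormDenseGeneral module above, assuming FP8 is disabled
# (so layernorm and the dot product are not fused). Shapes and axis names are arbitrary.
def _example_layernorm_dense_general_usage():
    key = jax_random.PRNGKey(0)
    # transpose_batch_sequence defaults to True, so inputs are (seqlen, batch, hidden).
    x = jax_random.normal(key, (16, 4, 32), dtype=jnp.float32)
    ln_dense = LayerNormDenseGeneral(features=64,
                                     layernorm_type='layernorm',
                                     kernel_axes=('embed', 'mlp'),
                                     return_layernorm_output=True)
    variables = ln_dense.init(key, x)
    z, ln_out = ln_dense.apply(variables, x)    # dense output and the layernorm output
    assert z.shape == (16, 4, 64) and ln_out.shape == x.shape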


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weights of the first linear transformation.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weights of the second linear transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh for
        the weights of the first linear transformation.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh for
        the weights of the second linear transformation.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in the RNG dict given to flax.linen.Module.apply that is used to generate
        Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden-layer output.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp')
    kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ('act', 'mlp')
    bias_axes_2: Tuple[str, ...] = ('embed',)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ('relu',)
    intermediate_dropout_rng_name: str = 'dropout'
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    major_sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        warnings.warn("major_sharding_type of LayerNormMLP will be removed in the near future",
                      DeprecationWarning)

        ln_output = None

        fuse_layernorm = FP8Helper.is_fp8_enabled(
        ) and not self.return_layernorm_output and self.enable_layernorm

        def is_geglu(acts):
            geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]

            normalize_acts = []
            for act in acts:
                if not isinstance(act, str):
                    return False
                normalize_acts.append(act.lower())
            return normalize_acts in geglu_act_pool

        use_fused_ln_mlp = fuse_layernorm \
            and (not self.use_bias) and is_geglu(self.activations) \
                and (self.intermediate_dropout_rate < 1e-3)

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1    # Only support axis == -1 at this moment

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)

        num_of_gemm = 2
        fp8_meta_package = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_package = \
                    TransformerEngineBase.get_fp8_meta_package(num_of_gemm)

        num_activations = len(self.activations)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
                                                   kernel_1_init,
                                                   num_activations,
                                                   -2,
                                                   kernel_1_each_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_1)
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
                                                   self.kernel_init,
                                                   kernel_2_param_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_2)
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        contract_ind = tuple(range(0, len(axis)))

        if use_fused_ln_mlp:
            assert self.axis == -1    # Only support axis == -1 at this moment

            out = layernrom_geglu_fp8_mlp(y,
                                          scale,
                                          ln_bias, [kernel_1, kernel_2],
                                          fp8_meta_package,
                                          self.layernorm_type,
                                          zero_centered_gamma=self.zero_centered_gamma,
                                          epsilon=self.epsilon)
        else:    # not use_fused_ln_mlp

            # DenseGeneral 1
            gemm1_fp8_meta_package = None if fp8_meta_package is None \
                                     else fp8_meta_package.get_package_by_gemm_idx(0)
            if fuse_layernorm:
                x = layernorm_fp8_dot(y,
                                      kernel_1,
                                      scale,
                                      ln_bias,
                                      gemm1_fp8_meta_package,
                                      self.layernorm_type,
                                      zero_centered_gamma=self.zero_centered_gamma,
                                      epsilon=self.epsilon)
            else:
                x = type_safe_dot_general(y,
                                          kernel_1,
                                          fp8_meta_pkg=gemm1_fp8_meta_package,
                                          contracting_dims=(axis, contract_ind))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wi_bias',
                                                       self.bias_init,
                                                       intermediate_dim,
                                                       jnp.float32,
                                                       axes=self.bias_axes_1)
                bias = bias.astype(self.dtype)
                bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
                x += jnp.reshape(bias, bias_shape)

            x = checkpoint_name(x, 'ffn1')

            activations = []
            if is_geglu(self.activations):
                z = geglu(x)
            else:
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(self.activations):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                z = jnp.reshape(z, (*z.shape[:-2], -1))

            z = nn.Dropout(rate=self.intermediate_dropout_rate,
                           broadcast_dims=self.intermediate_hidden_dropout_dims,
                           rng_collection=self.intermediate_dropout_rng_name)(
                               z, deterministic=deterministic)

            # DenseGeneral 2
            gemm2_fp8_meta_package = None if fp8_meta_package is None \
                                     else fp8_meta_package.get_package_by_gemm_idx(1)

            out = type_safe_dot_general(z,
                                        kernel_2,
                                        fp8_meta_pkg=gemm2_fp8_meta_package,
                                        contracting_dims=(axis, contract_ind))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wo_bias',
                                                       self.bias_init, (hidden_size,),
                                                       jnp.float32,
                                                       axes=self.bias_axes_2)
                bias = bias.astype(self.dtype)
                out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))

            out = checkpoint_name(out, 'ffn2')

        return out, ln_output    # output, layer_norm_output
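

# --- Illustrative usage sketch (added as documentation; not part of the library API) ---
# A minimal example of the LayerNormMLP module above, assuming FP8 is disabled so the
# unfused path (layernorm, GEGLU activation, two GEMMs) is taken. Shapes and axis
# names are arbitrary.
def _example_layernorm_mlp_usage():
    key = jax_random.PRNGKey(0)
    # transpose_batch_sequence defaults to True, so inputs are (seqlen, batch, hidden).
    x = jax_random.normal(key, (16, 4, 32), dtype=jnp.float32)
    mlp = LayerNormMLP(intermediate_dim=128,
                       activations=('gelu', 'linear'),
                       intermediate_dropout_rate=0.0,
                       return_layernorm_output=False)
    variables = mlp.init(key, x, deterministic=True)
    out, ln_out = mlp.apply(variables, x, deterministic=True)
    assert out.shape == x.shape and ln_out is None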