# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
import warnings
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
from ..layernorm import canonicalize_layernorm_type
from ..layernorm import layernorm, layernorm_fp8_dot
from ..mlp import layernrom_geglu_fp8_mlp, geglu
from ..softmax import is_softmax_kernel_available
from ..softmax import softmax, SoftmaxType

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
                                                                       lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
    if original_init is None:
        if not zero_centered_gamma:
            return nn.initializers.ones
        return nn.initializers.zeros
    return original_init


def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
                                 bias_axes, dtype):
    scale = nn_partitioning.param_with_axes('scale',
                                            scale_init,
                                            shape,
                                            jnp.float32,
                                            axes=scale_axes)
    scale = jnp.asarray(scale, dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == 'layernorm':
        bias = nn_partitioning.param_with_axes('ln_bias',
                                               bias_init,
                                               shape,
                                               jnp.float32,
                                               axes=bias_axes)
        bias = jnp.asarray(bias, dtype)
    else:
        assert layernorm_type == 'rmsnorm'
        bias = None

    return scale, bias


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string

    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def _combine_biases(*masks: List[Array]):
    """Combine attention biases."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(map(lambda x: x.ndim == masks[0].ndim,
                   masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


class Softmax(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
    The input's shape should be [batch, heads, q_seqlen, k_seqlen].

    .. code-block:: python

        shifted_input = input + bias
        masked_scaled = (1 - mask) * (shifted_input * scale_factor)
        softmax_mask = mask * -1e10
        output = softmax(masked_scaled + softmax_mask)

    Parameters
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
        Indicate the type of softmax.
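
    Examples
    --------
    A minimal usage sketch. The shapes, head count, and scale factor below are
    illustrative assumptions only; the module itself only requires the documented
    [batch, heads, q_seqlen, k_seqlen] layout.

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        softmax_module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
        logits = jnp.ones((2, 16, 128, 128))    # [batch, heads, q_seqlen, k_seqlen]
        variables = softmax_module.init(jax.random.PRNGKey(0), logits)
        probs = softmax_module.apply(variables, logits)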
    """

    scale_factor: float = 1.0
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    @nn.compact
    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
        k_seqlen = inputs.shape[3]
        dtype = inputs.dtype
        logits = inputs

        if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
                self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)):

            if bias is not None:
                logits = logits + bias.astype(dtype)

            mask_ = mask
            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
                mask_ = None

            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
        else:
            attention_bias = None
            if mask is not None:
                attention_bias = lax.select(mask > 0,
                                            jnp.full(mask.shape, -1e10).astype(dtype),
                                            jnp.full(mask.shape, 0.).astype(dtype))

            if bias is not None:
                attention_bias = _combine_biases(attention_bias, bias)

            if attention_bias is not None:
                logits = logits + attention_bias.astype(dtype)

            # For the case that the kernel for self.softmax_type (e.g.
            # SoftmaxType.SCALED_UPPER_TRIANG_MASKED) is unavailable, fall back to the
            # pure scaled softmax custom call.
            if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
                                           dtype):
                outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
            else:
                outputs = jax_nn.softmax(logits * self.scale_factor)

        return outputs


class LayerNorm(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies layer normalization over a mini-batch of inputs.
    There are two types of normalization supported by this module,
    regular and root mean square layer Normalization.

    The regular layer normalization is as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size of each input sample.

    The root mean square layer normalization (RMSNorm) is as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma

    .. math::
        RMS = \sqrt{\mathrm{E}[x^2]}

    :math:`\gamma` is a learnable affine transform parameter of
    size of each input sample.

    Parameters
    ----------
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`layernorm_type='layernorm'`.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
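
    Examples
    --------
    A minimal usage sketch; the shapes below are illustrative assumptions:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNorm(layernorm_type='rmsnorm', epsilon=1e-6)
        x = jnp.ones((8, 128, 512))    # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)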
    """
    epsilon: float = 1e-6
    layernorm_type: str = 'layernorm'
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ('embed',)
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type = None

    def __post_init__(self):
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies layer normalization to the input :attr:`inputs`.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        warnings.warn("sharding_type of LayerNorm would be removed in the near feature",
                      DeprecationWarning)

        features = x.shape[-1]
        scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                      self.scale_init, self.scale_axes,
                                                      self.bias_init, self.bias_axes, self.dtype)
        return layernorm(x,
                         scale,
                         ln_bias,
                         layernorm_type=self.layernorm_type,
                         zero_centered_gamma=self.zero_centered_gamma,
                         epsilon=self.epsilon)


class TransformerEngineBase(nn.Module):
    """
    Base class of transformer engine
    """

    @staticmethod
    def get_fp8_metas(num_of_gemm: int) -> List[jnp.ndarray]:
        """
        Get the FP8 metas
        """
        num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
        axes = ('fp8_meta_axis', 'fp8_meta_history')

        fp8_max = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                     FP8Helper.FP8_MAX_NAME,
                                                     FP8Helper.generate_fp8_max_array,
                                                     num_of_meta,
                                                     axes=axes)
        fp8_metas_amax = nn_partitioning.variable_with_axes(
            FP8Helper.FP8_COLLECTION_NAME,
            FP8Helper.FP8_AMAX_NAME,
            jnp.zeros, (num_of_meta, FP8Helper.AMAX_HISTORY_LEN),
            jnp.float32,
            axes=axes)
        fp8_metas_scale = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                             FP8Helper.FP8_SCALE_NAME,
                                                             jnp.ones, (num_of_meta, 1),
                                                             jnp.float32,
                                                             axes=axes)
        fp8_metas_scale_inv = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
                                                                 FP8Helper.FP8_SCALE_INV_NAME,
                                                                 jnp.ones, (num_of_meta, 1),
                                                                 jnp.float32,
                                                                 axes=axes)

        return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value

    @staticmethod
    def get_fp8_meta_package(num_of_gemm: int) -> FP8MetaPackage:
        """
        Get the FP8 metas
        """
        fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
            TransformerEngineBase.get_fp8_metas(num_of_gemm)

        return FP8MetaPackage(num_of_gemm, fp8_max, fp8_metas_amax, fp8_metas_scale,
                              fp8_metas_scale_inv)


class DenseGeneral(TransformerEngineBase):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = True
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = False
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
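
    Examples
    --------
    A minimal sketch without FP8 or a sharding mesh; the feature size and input
    shape are illustrative assumptions:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = DenseGeneral(features=1024, use_bias=True)
        x = jnp.ones((8, 128, 512))    # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y = layer.apply(variables, x)    # y.shape == (8, 128, 1024)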
    """

    features: Union[Iterable[int], int]
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False
    sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : jax.numpy.ndarray
            Input tensors.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        """
        warnings.warn("sharding_type of DenseGeneral would be removed in the near feature",
                      DeprecationWarning)

        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))
        fp8_gemm_pkg = None
        if FP8Helper.is_fp8_enabled():
            fp8_gemm_pkg = \
                    TransformerEngineBase.get_fp8_meta_package(1)

        y = type_safe_dot_general(inputs,
                                  kernel,
                                  fp8_meta_pkg=fp8_gemm_pkg,
                                  contracting_dims=(axis, contract_ind))

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
    depth_scaling: float, default = None
        The factor to scale the output from `DenseGeneral`. It should be a float
        value or None. When None is set, then no scaling is applied.
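
    Examples
    --------
    A minimal sketch without FP8 (so the fused layernorm + GEMM path is not taken);
    the shapes and feature size are illustrative assumptions:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        layer = LayerNormDenseGeneral(features=1024,
                                      transpose_batch_sequence=False)
        x = jnp.ones((8, 128, 512))    # (batch, seqlen, hidden)
        variables = layer.init(jax.random.PRNGKey(0), x)
        y, ln_out = layer.apply(variables, x)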
    """

    features: Union[Iterable[int], int]
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    depth_scaling: float = None
    sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        warnings.warn("sharding_type of LayerNormDenseGeneral would be removed in the near feature",
                      DeprecationWarning)

        ln_output = None

        fuse_layernorm = FP8Helper.is_fp8_enabled() \
            and not self.return_layernorm_output and self.enable_layernorm

        if self.enable_layernorm:
            assert self.axis == -1    # Only support axis == -1 at this moment
            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        # DenseGeneral
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.reshape(kernel, kernel_shape)

        contract_ind = tuple(range(0, len(axis)))

        fp8_meta_package = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_package = \
                    TransformerEngineBase.get_fp8_meta_package(1)

        if fuse_layernorm:
            z = layernorm_fp8_dot(y,
                                  kernel,
                                  scale,
                                  ln_bias,
                                  fp8_meta_package,
                                  self.layernorm_type,
                                  zero_centered_gamma=self.zero_centered_gamma,
                                  epsilon=self.epsilon)
        else:
            z = type_safe_dot_general(y,
                                      kernel,
                                      fp8_meta_pkg=fp8_meta_package,
                                      contracting_dims=(axis, contract_ind))

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   jnp.float32,
                                                   axes=self.bias_axes)
            bias = bias.astype(self.dtype)

        if bias is not None:
            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
            z += jnp.reshape(bias, bias_shape)

        if self.depth_scaling is not None:
            z = z / self.depth_scaling

        return z, ln_output    # dense_output, layer_norm_output


class LayerNormMLP(TransformerEngineBase):
    r"""
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by given activations.

    Parameters
    ----------
    intermediate_dim: int, default = 2048
        Intermediate size to which input samples are projected.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before linear transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
        If set to `True`, the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
            (1 + \gamma) + \beta

        This parameter is only applicable for 'layernorm'.
        The default of `scale_init` will also be changed. See `scale_init`.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
        Otherwise, scale_init is `flax.linen.initializers.ones`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both linear transformations.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first linear transformations.
    kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the second linear transformations.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
        If set to False, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
    bias_axes_1: Tuple[str, ...], default = ('act', 'mlp')
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the first linear transformations.
        Only used when :attr:`use_bias=True`.
    bias_axes_2: Tuple[str, ...], default = ('embed',)
        The name of axes used to shard bias with a corresponding mesh  for
        the weight of the second linear transformations.
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set to False, return None as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('relu',)
        The sequence of activation functions to apply after the first linear transformation.
        Each activation has its own transformation layer.
    intermediate_dropout_rng_name: str, default = 'dropout'
        The key in given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
    intermediate_dropout_rate: float, default = 0.1
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for the hidden dropout.
    axis:  Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.

    Optimization parameters
    -----------------------
    dtype : jax.numpy.dtype, default  = jax.numpy.float32
        The data type used to allocate the initial parameters.
    transpose_batch_sequence : bool, default = True
        Indicate whether the input tensors were switched axis of batch
        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
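
    Examples
    --------
    A minimal sketch without FP8; dropout is disabled via `deterministic=True` and
    the shapes are illustrative assumptions:

    .. code-block:: python

        import jax
        import jax.numpy as jnp

        mlp = LayerNormMLP(intermediate_dim=2048,
                           activations=('gelu', 'linear'),
                           transpose_batch_sequence=False)
        x = jnp.ones((8, 128, 512))    # (batch, seqlen, hidden)
        variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
        out, ln_out = mlp.apply(variables, x, deterministic=True)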
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    layernorm_type: str = 'layernorm'
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    scale_axes: Tuple[str, ...] = ('embed',)
    ln_bias_init: Initializer = nn.initializers.zeros
    ln_bias_axes: Tuple[str, ...] = ('embed',)
    kernel_init: Initializer = None
    kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp')
    kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes_1: Tuple[str, ...] = ('act', 'mlp')
    bias_axes_2: Tuple[str, ...] = ('embed',)
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ('relu',)
    intermediate_dropout_rng_name: str = 'dropout'
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True
    major_sharding_type = None

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        self.scale_init = _obtain_default_layernorm_scale_init_if_need(
            self.scale_init, self.zero_centered_gamma)
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inputs: jax.numpy.ndarray
            Input tensor.
        deterministic: bool, default  = False
            Disable dropout ops if set to True.

        Returns
        -------
        outputs : jax.numpy.ndarray
            Output tensors.
        ln_outputs: jax.numpy.ndarray
            The output tensors of layer normalization.
            If :attr:`return_layernorm_output=False`, then this would be None.
        """
        warnings.warn("major_sharding_type of LayerNormMLP would be removed in the near feature",
                      DeprecationWarning)

        ln_output = None

        fuse_layernorm = FP8Helper.is_fp8_enabled() \
            and not self.return_layernorm_output and self.enable_layernorm

        def is_geglu(acts):
            geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')]

            normalize_acts = []
            for act in acts:
                if not isinstance(act, str):
                    return False
                normalize_acts.append(act.lower())
            return normalize_acts in geglu_act_pool

        use_fused_ln_mlp = fuse_layernorm \
            and (not self.use_bias) and is_geglu(self.activations) \
                and (self.intermediate_dropout_rate < 1e-3)

        # LayerNorm
        if self.enable_layernorm:
            assert self.axis == -1    # Only support axis == -1 at this moment

            features = inputs.shape[-1]

            scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
                                                          self.scale_init, self.scale_axes,
                                                          self.ln_bias_init, self.ln_bias_axes,
                                                          self.dtype)

            if not fuse_layernorm:
                y = layernorm(inputs,
                              scale,
                              ln_bias,
                              layernorm_type=self.layernorm_type,
                              zero_centered_gamma=self.zero_centered_gamma,
                              epsilon=self.epsilon)
            else:
                assert not self.return_layernorm_output
                y = inputs
        else:
            y = inputs

        if self.return_layernorm_output:
            ln_output = y

        def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            kernels = []
            for _ in range(num_kernels):
                key, init_key = jax_random.split(key)
                kernels.append(self.kernel_init(init_key, *init_args))
            return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)

        num_of_gemm = 2
        fp8_meta_package = None
        if FP8Helper.is_fp8_enabled():
            fp8_meta_package = \
                    TransformerEngineBase.get_fp8_meta_package(num_of_gemm)

        num_activations = len(self.activations)
        axis = _canonicalize_tuple(self.axis)
        axis = _normalize_axes(axis, y.ndim)

        intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
        kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
        kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
        kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
                                                   kernel_1_init,
                                                   num_activations,
                                                   -2,
                                                   kernel_1_each_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_1)
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
        hidden_size = inputs.shape[-1]
        hidden_size_tuple = _canonicalize_tuple(hidden_size)
        kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
        kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
        kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
                                                   self.kernel_init,
                                                   kernel_2_param_shape,
                                                   jnp.float32,
                                                   axes=self.kernel_axes_2)
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
        contract_ind = tuple(range(0, len(axis)))

        if use_fused_ln_mlp:
            assert self.axis == -1    # Only support axis == -1 at this moment

            out = layernrom_geglu_fp8_mlp(y,
                                          scale,
                                          ln_bias, [kernel_1, kernel_2],
                                          fp8_meta_package,
                                          self.layernorm_type,
                                          zero_centered_gamma=self.zero_centered_gamma,
                                          epsilon=self.epsilon)
        else:    # not use_fused_ln_mlp

            # DenseGeneral 1
            gemm1_fp8_meta_package = None if fp8_meta_package is None \
                                     else fp8_meta_package.get_package_by_gemm_idx(0)
            if fuse_layernorm:
                x = layernorm_fp8_dot(y,
                                      kernel_1,
                                      scale,
                                      ln_bias,
                                      gemm1_fp8_meta_package,
                                      self.layernorm_type,
                                      zero_centered_gamma=self.zero_centered_gamma,
                                      epsilon=self.epsilon)
            else:
                x = type_safe_dot_general(y,
                                          kernel_1,
                                          fp8_meta_pkg=gemm1_fp8_meta_package,
                                          contracting_dims=(axis, contract_ind))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wi_bias',
                                                       self.bias_init,
                                                       intermediate_dim,
                                                       jnp.float32,
                                                       axes=self.bias_axes_1)
                bias = bias.astype(self.dtype)
                bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
                x += jnp.reshape(bias, bias_shape)

            activations = []
            if is_geglu(self.activations):
                z = geglu(x)
            else:
                x = jnp.split(x, num_activations, axis=-2)
                for idx, act_fn in enumerate(self.activations):
                    x_i = _convert_to_activation_function(act_fn)(x[idx])
                    activations.append(x_i)
                z = functools.reduce(operator.mul, activations)
                z = jnp.reshape(z, (*z.shape[:-2], -1))

            z = nn.Dropout(rate=self.intermediate_dropout_rate,
                           broadcast_dims=self.intermediate_hidden_dropout_dims,
                           rng_collection=self.intermediate_dropout_rng_name)(
                               z, deterministic=deterministic)

            # DenseGeneral 2
            gemm2_fp8_meta_package = None if fp8_meta_package is None \
                                     else fp8_meta_package.get_package_by_gemm_idx(1)

            out = type_safe_dot_general(z,
                                        kernel_2,
                                        fp8_meta_pkg=gemm2_fp8_meta_package,
                                        contracting_dims=(axis, contract_ind))

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wo_bias',
                                                       self.bias_init, (hidden_size,),
                                                       jnp.float32,
                                                       axes=self.bias_axes_2)
                bias = bias.astype(self.dtype)
                out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))

        return out, ln_output    # Output, layer_norm_output