# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
5
6

# Submodule exposed at package level and used as the source of the deprecated
# aliases defined below.
from . import flax
# Public API re-exported by this package (see `__all__`).
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .sharding import MajorShardingType, ShardingResource, ShardingType

# Helper used to wrap each legacy package-level name together with its
# relocation notice.
from ..common.utils import deprecate_wrapper

# Legacy package-level alias for `flax.extend_logical_axis_rules`, kept for
# backward compatibility. The object is passed through `deprecate_wrapper`
# together with the relocation notice; how that notice is surfaced is decided
# by `deprecate_wrapper` (defined in `common.utils`, not visible here).
extend_logical_axis_rules = deprecate_wrapper(
    flax.extend_logical_axis_rules,
    "extend_logical_axis_rules is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
15
# Legacy package-level alias for `flax.DenseGeneral`, kept for backward
# compatibility. The object is passed through `deprecate_wrapper` together
# with the relocation notice; how that notice is surfaced is decided by
# `deprecate_wrapper` (defined in `common.utils`, not visible here).
DenseGeneral = deprecate_wrapper(
    flax.DenseGeneral,
    "DenseGeneral is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
18
# Legacy package-level alias for `flax.LayerNorm`, kept for backward
# compatibility. The object is passed through `deprecate_wrapper` together
# with the relocation notice; how that notice is surfaced is decided by
# `deprecate_wrapper` (defined in `common.utils`, not visible here).
LayerNorm = deprecate_wrapper(
    flax.LayerNorm,
    "LayerNorm is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
21
22
# Legacy package-level alias for `flax.LayerNormDenseGeneral`, kept for
# backward compatibility. The object is passed through `deprecate_wrapper`
# together with the relocation notice; how that notice is surfaced is decided
# by `deprecate_wrapper` (defined in `common.utils`, not visible here).
LayerNormDenseGeneral = deprecate_wrapper(
    flax.LayerNormDenseGeneral,
    "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
25
# Legacy package-level alias for `flax.LayerNormMLP`, kept for backward
# compatibility. The object is passed through `deprecate_wrapper` together
# with the relocation notice; how that notice is surfaced is decided by
# `deprecate_wrapper` (defined in `common.utils`, not visible here).
LayerNormMLP = deprecate_wrapper(
    flax.LayerNormMLP,
    "LayerNormMLP is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
28
29
# Legacy package-level alias for `flax.TransformerEngineBase`, kept for
# backward compatibility. The object is passed through `deprecate_wrapper`
# together with the relocation notice; how that notice is surfaced is decided
# by `deprecate_wrapper` (defined in `common.utils`, not visible here).
TransformerEngineBase = deprecate_wrapper(
    flax.TransformerEngineBase,
    "TransformerEngineBase is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
32
# Legacy package-level alias for `flax.MultiHeadAttention`, kept for backward
# compatibility. The object is passed through `deprecate_wrapper` together
# with the relocation notice; how that notice is surfaced is decided by
# `deprecate_wrapper` (defined in `common.utils`, not visible here).
MultiHeadAttention = deprecate_wrapper(
    flax.MultiHeadAttention,
    "MultiHeadAttention is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
35
36
# Legacy package-level alias for `flax.RelativePositionBiases`, kept for
# backward compatibility. The object is passed through `deprecate_wrapper`
# together with the relocation notice; how that notice is surfaced is decided
# by `deprecate_wrapper` (defined in `common.utils`, not visible here).
RelativePositionBiases = deprecate_wrapper(
    flax.RelativePositionBiases,
    "RelativePositionBiases is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
39
# Legacy package-level alias for `flax.TransformerLayer`, kept for backward
# compatibility. The object is passed through `deprecate_wrapper` together
# with the relocation notice; how that notice is surfaced is decided by
# `deprecate_wrapper` (defined in `common.utils`, not visible here).
TransformerLayer = deprecate_wrapper(
    flax.TransformerLayer,
    "TransformerLayer is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
42
43
# Legacy package-level alias for `flax.TransformerLayerType`, kept for
# backward compatibility. The object is passed through `deprecate_wrapper`
# together with the relocation notice; how that notice is surfaced is decided
# by `deprecate_wrapper` (defined in `common.utils`, not visible here).
TransformerLayerType = deprecate_wrapper(
    flax.TransformerLayerType,
    "TransformerLayerType is moving to transformer_engine.jax.flax module"
    " and will be fully removed in the next release (v1.0.0).",
)
46
47
48

# Names exported by `from transformer_engine.jax import *`.
#
# NOTE(review): 'praxis' is not bound anywhere in the visible part of this
# module. For a package, a name in `__all__` that is not yet an attribute is
# imported as a submodule when `import *` runs, so this presumably refers to
# the `praxis` submodule -- confirm it exists.
__all__ = [
    # FP8 helpers
    'fp8_autocast',
    'update_collections',
    'update_fp8_metas',
    'get_delayed_scaling',
    # Sharding configuration
    'MajorShardingType',
    'ShardingResource',
    'ShardingType',
    # Submodules
    'flax',
    'praxis',
    # Deprecated package-level aliases (wrapped with `deprecate_wrapper`).
    # `extend_logical_axis_rules` was the only alias missing from this list;
    # it is included so that all of them are exported consistently.
    'extend_logical_axis_rules',
    'DenseGeneral',
    'LayerNorm',
    'LayerNormDenseGeneral',
    'LayerNormMLP',
    'TransformerEngineBase',
    'MultiHeadAttention',
    'RelativePositionBiases',
    'TransformerLayer',
    'TransformerLayerType',
]