# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""

from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum

# Backward-compatibility aliases for the legacy sharding API.
# The original classes imported from .sharding are rebound under the same public
# names, wrapped by the deprecation helpers from ..common.utils. Each helper is
# handed the message it should report. Presumably the helpers warn when the
# wrapped object is used; see ..common.utils to confirm.
MajorShardingType = DeprecatedEnum(
    MajorShardingType,
    "MajorShardingType is deprecated and will be removed in the near future.",
)
ShardingType = DeprecatedEnum(
    ShardingType,
    "ShardingType is deprecated and will be removed in the near future.",
)
# ShardingResource is superseded by MeshResource (also imported from .sharding).
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)

# Public API of this package: the names exported by ``from <package> import *``.
# Deprecated sharding aliases are kept so existing callers keep working.
__all__ = [
    'NVTE_FP8_COLLECTION_NAME',
    'fp8_autocast',
    'update_collections',
    'update_fp8_metas',
    'get_delayed_scaling',
    'MeshResource',
    'MajorShardingType',
    'ShardingResource',
    'ShardingType',
    'flax',
    # NOTE(review): 'praxis' is not imported in this module. For a package,
    # ``import *`` will attempt to import the ``praxis`` submodule on demand;
    # confirm that submodule (and its optional dependency) is meant to be
    # pulled in this way.
    'praxis',
]