# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
# pylint: disable=wrong-import-position,wrong-import-order

import ctypes

from transformer_engine.common import get_te_path
from transformer_engine.common import _get_sys_extension


def _load_library():
    """Load shared library with Transformer Engine C extensions.

    Looks for a file matching ``transformer_engine_jax.*.<ext>`` (``<ext>`` from
    ``_get_sys_extension()``). It tries ``<te_path>/transformer_engine`` first and
    then ``<te_path>`` itself, where ``<te_path>`` comes from ``get_te_path()``.
    The first match found is loaded with ``ctypes.RTLD_GLOBAL``.

    Returns:
        ctypes.CDLL: handle to the loaded shared library.

    Raises:
        FileNotFoundError: if neither directory contains a matching library.
    """
    extension = _get_sys_extension()
    te_path = get_te_path()
    pattern = f"transformer_engine_jax.*.{extension}"
    # Search order matters: the nested package directory wins over the root.
    search_dirs = (te_path / "transformer_engine", te_path)
    for so_dir in search_dirs:
        so_path = next(so_dir.glob(pattern), None)
        if so_path is not None:
            # ctypes.CDLL only documents path-like `name` support from Python
            # 3.12, so pass a plain string for older interpreters.
            return ctypes.CDLL(str(so_path), mode=ctypes.RTLD_GLOBAL)

    searched = ", ".join(str(so_dir) for so_dir in search_dirs)
    raise FileNotFoundError(
        f"Could not find the Transformer Engine JAX extension '{pattern}' "
        f"in any of: {searched}"
    )


# Load the C extension (with RTLD_GLOBAL) before importing the submodules below;
# this ordering is why the module disables pylint's wrong-import-position check.
# NOTE(review): presumably the submodules depend on the library's symbols being
# loaded and globally visible first — confirm.
_TE_JAX_LIB_CTYPES = _load_library()
from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType

from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum

# Re-bind the legacy sharding names imported from `.sharding` to wrapped versions
# built by `DeprecatedEnum` / `deprecate_wrapper` (from `..common.utils`), each
# carrying a deprecation message; `ShardingResource` points users to `MeshResource`.
# NOTE(review): presumably the wrappers emit the message as a warning on use —
# confirm against transformer_engine.common.utils.
MajorShardingType = DeprecatedEnum(
    MajorShardingType, "MajorShardingType is deprecating in the near future."
)
ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near future.")
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)

# Public names exported by `from transformer_engine.jax import *`.
# "praxis" is not imported in this file; it presumably names a subpackage, which
# a package star-import loads on demand because it is listed here.
# NOTE(review): confirm the `praxis` subpackage exists.
__all__ = [
    "NVTE_FP8_COLLECTION_NAME",
    "fp8_autocast",
    "update_collections",
    "get_delayed_scaling",
    "MeshResource",
    "MajorShardingType",
    "ShardingResource",
    "ShardingType",
    "flax",
    "praxis",
]