# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""

# pylint: disable=wrong-import-position,wrong-import-order

import ctypes

from transformer_engine.common import get_te_path
from transformer_engine.common import _get_sys_extension


def _load_library():
    """Load the Transformer Engine JAX C extension with global symbol visibility.

    The shared object matching ``transformer_engine_jax.*.<ext>`` (``<ext>`` as
    reported by ``_get_sys_extension()``) is searched for first in
    ``get_te_path() / "transformer_engine"`` and then directly in
    ``get_te_path()``. The first match found is loaded with ``RTLD_GLOBAL``.

    Returns:
        ctypes.CDLL: handle to the loaded shared library.

    Raises:
        FileNotFoundError: if no matching shared object exists in any of the
            searched directories.
        OSError: propagated from ``ctypes.CDLL`` if the dynamic loader cannot
            load the file that was found.
    """
    extension = _get_sys_extension()
    te_path = get_te_path()
    pattern = f"transformer_engine_jax.*.{extension}"
    # Same search order as before: the package subdirectory first, then the root.
    search_dirs = (te_path / "transformer_engine", te_path)

    for so_dir in search_dirs:
        so_path = next(so_dir.glob(pattern), None)
        if so_path is not None:
            # str() keeps this portable to Python < 3.12, where ctypes.CDLL only
            # documented str names (path-like support was added in 3.12).
            # RTLD_GLOBAL makes the library's symbols available for resolving
            # symbols of libraries loaded afterwards.
            return ctypes.CDLL(str(so_path), mode=ctypes.RTLD_GLOBAL)

    searched = ", ".join(str(d) for d in search_dirs)
    raise FileNotFoundError(
        f"Could not find Transformer Engine JAX extension '{pattern}' in: {searched}")


# Load the C extension at package import time; the handle is kept at module
# level so the loaded library is not released.
_TE_JAX_LIB_CTYPES = _load_library()

# NOTE(review): these imports intentionally come after the library load (see
# the wrong-import-position suppression at the top of the file); presumably the
# submodules need the extension already loaded -- keep this order.
from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType

from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum

# Legacy sharding names are re-exported wrapped in the deprecation helpers from
# ..common.utils together with a message pointing users at the replacement.
MajorShardingType = DeprecatedEnum(MajorShardingType,
                                   "MajorShardingType is deprecating in the near future.")
ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near future.")
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.")

# Public API of the package. For a package, `from ... import *` also imports
# any submodule named here, so 'praxis' does not need an explicit import above
# (assumes a `praxis` submodule exists in this package -- confirm).
__all__ = [
    'NVTE_FP8_COLLECTION_NAME',
    'fp8_autocast',
    'update_collections',
    'get_delayed_scaling',
    'MeshResource',
    'MajorShardingType',
    'ShardingResource',
    'ShardingType',
    'flax',
    'praxis',
]