# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""

# pylint: disable=wrong-import-position,wrong-import-order

import sys
import logging
import importlib
import importlib.util
import ctypes
from importlib.metadata import version

from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension

# Module-level logger; used by _load_library for informational messages.
_logger = logging.getLogger(__name__)


def _load_library():
    """Load the shared library with Transformer Engine's JAX C extensions.

    If the ``transformer_engine_jax`` distribution is installed, first check
    that ``transformer-engine`` and ``transformer-engine-cu12`` are installed
    too and that all three report the same version. Then locate the compiled
    ``transformer_engine_jax.*.<ext>`` file under ``get_te_path()``, execute it,
    and register it in ``sys.modules`` under ``transformer_engine_jax``.

    Raises:
        AssertionError: ``transformer_engine_jax`` is installed but a companion
            distribution is missing, or the three versions differ.
        ImportError: no matching shared library is found in any searched
            directory, or no import spec can be built for the file found.
    """
    module_name = "transformer_engine_jax"

    if is_package_installed(module_name):
        # Raise explicitly instead of using ``assert`` so these checks are not
        # stripped when Python runs with ``-O``.
        if not is_package_installed("transformer_engine"):
            raise AssertionError("Could not find `transformer-engine`.")
        if not is_package_installed("transformer_engine_cu12"):
            raise AssertionError("Could not find `transformer-engine-cu12`.")
        jax_version = version(module_name)
        te_version = version("transformer-engine")
        cu12_version = version("transformer-engine-cu12")
        if not jax_version == te_version == cu12_version:
            raise AssertionError(
                "TransformerEngine package version mismatch. Found"
                f" {module_name} v{jax_version}, transformer-engine"
                f" v{te_version}, and transformer-engine-cu12"
                f" v{cu12_version}. Install transformer-engine using "
                "'pip3 install transformer-engine[jax]==VERSION'"
            )

    if is_package_installed("transformer-engine-cu12") and not is_package_installed(module_name):
        _logger.info(
            "Could not find package %s. Install transformer-engine using "
            "'pip3 install transformer-engine[jax]==VERSION'",
            module_name,
        )

    extension = _get_sys_extension()
    pattern = f"{module_name}.*.{extension}"
    te_path = get_te_path()
    # Directories are searched in this order; the first match wins.
    search_dirs = (
        te_path / "transformer_engine",
        te_path / "transformer_engine" / "wheel_lib",
        te_path,
    )
    so_path = next(
        (match for so_dir in search_dirs for match in so_dir.glob(pattern)),
        None,
    )
    if so_path is None:
        # Previously a bare StopIteration escaped here with no explanation.
        raise ImportError(
            f"Could not find a shared library matching '{pattern}' in: "
            + ", ".join(str(so_dir) for so_dir in search_dirs)
        )

    spec = importlib.util.spec_from_file_location(module_name, so_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {so_path}")
    solib = importlib.util.module_from_spec(spec)
    # Register before executing, following the importlib recipe for
    # importing a source file directly.
    sys.modules[module_name] = solib
    spec.loader.exec_module(solib)


# Load and register the C extension before the submodule imports below (hence
# the wrong-import-position pylint disable). NOTE(review): presumably those
# submodules import ``transformer_engine_jax`` at import time — confirm.
_load_library()
from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType

from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum

# Re-export the legacy sharding names wrapped so that using them reports a
# deprecation message.
MajorShardingType = DeprecatedEnum(
    MajorShardingType, "MajorShardingType is deprecated and will be removed in the near future."
)
ShardingType = DeprecatedEnum(
    ShardingType, "ShardingType is deprecated and will be removed in the near future."
)
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)

__all__ = [
    "NVTE_FP8_COLLECTION_NAME",
    "fp8_autocast",
    "update_collections",
    "get_delayed_scaling",
    "MeshResource",
    "MajorShardingType",
    "ShardingResource",
    "ShardingType",
    "flax",
    # NOTE(review): ``praxis`` is never imported in this module. A star-import
    # only succeeds if a ``praxis`` submodule exists in this package;
    # otherwise it raises AttributeError — confirm the submodule still exists.
    "praxis",
]