# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""

# pylint: disable=wrong-import-position,wrong-import-order

import ctypes
import logging
from importlib.metadata import version

from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension

# Module-level logger; _load_library uses it to report a missing optional package.
_logger = logging.getLogger(__name__)


def _load_library():
    """Load shared library with Transformer Engine C extensions.

    If the ``transformer_engine_jax`` distribution is installed, first check that
    ``transformer-engine`` and ``transformer-engine-cu12`` are installed as well and
    that all three report the same version. Then locate a file matching
    ``transformer_engine_jax.*.<ext>`` (``<ext>`` comes from ``_get_sys_extension()``).
    The search looks in ``get_te_path() / "transformer_engine"`` first and falls back
    to ``get_te_path()`` itself. The match is loaded with ``ctypes.RTLD_GLOBAL``.

    Returns:
        ctypes.CDLL: handle to the loaded shared library.

    Raises:
        AssertionError: if ``transformer_engine_jax`` is installed but a companion
            package is missing or the versions differ. These are ``assert``
            statements, so the checks are skipped under ``python -O``.
        FileNotFoundError: if no matching shared library exists in either
            search directory.
    """
    module_name = "transformer_engine_jax"

    if is_package_installed(module_name):
        assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
        assert is_package_installed(
            "transformer_engine_cu12"
        ), "Could not find `transformer-engine-cu12`."
        assert (
            version(module_name)
            == version("transformer-engine")
            == version("transformer-engine-cu12")
        ), (
            "TransformerEngine package version mismatch. Found"
            f" {module_name} v{version(module_name)}, transformer-engine"
            f" v{version('transformer-engine')}, and transformer-engine-cu12"
            f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install"
            " transformer-engine[jax]==VERSION'"
        )

    # transformer-engine-cu12 is present without the JAX binding package. This is
    # only logged: the library search below does not depend on it.
    if is_package_installed("transformer-engine-cu12") and not is_package_installed(module_name):
        _logger.info(
            "Could not find package %s. Install transformer-engine using 'pip"
            " install transformer-engine[jax]==VERSION'",
            module_name,
        )

    extension = _get_sys_extension()
    pattern = f"{module_name}.*.{extension}"
    te_path = get_te_path()
    # Search the ``transformer_engine`` subdirectory first, then the TE root itself.
    search_dirs = (te_path / "transformer_engine", te_path)
    for so_dir in search_dirs:
        so_path = next(so_dir.glob(pattern), None)
        if so_path is not None:
            # RTLD_GLOBAL makes this library's symbols available for resolving
            # libraries loaded afterwards. str() because ctypes accepts path-like
            # names on every platform only from Python 3.12 onward.
            return ctypes.CDLL(str(so_path), mode=ctypes.RTLD_GLOBAL)

    # Raise an explicit, descriptive error instead of letting a bare StopIteration
    # from next() escape at import time.
    raise FileNotFoundError(
        f"Could not find a shared library matching '{pattern}' in"
        f" {search_dirs[0]} or {search_dirs[1]}. Install transformer-engine using"
        " 'pip install transformer-engine[jax]==VERSION'"
    )


# Load the native extension before importing the Python submodules below.
# NOTE(review): presumably those submodules rely on the extension's symbols already
# being loaded (which would explain the wrong-import-position pylint disable at the
# top of this file) — confirm before reordering.
_TE_JAX_LIB_CTYPES = _load_library()

from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType

from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum

# Rebind the legacy sharding names imported from ``.sharding`` to deprecation
# wrappers. The messages are passed to DeprecatedEnum / deprecate_wrapper;
# presumably they are emitted when the wrapped names are used — confirm in
# ``common.utils``.
MajorShardingType = DeprecatedEnum(
    MajorShardingType, "MajorShardingType is deprecated and will be removed in the near future."
)
ShardingType = DeprecatedEnum(
    ShardingType, "ShardingType is deprecated and will be removed in the near future."
)
ShardingResource = deprecate_wrapper(
    ShardingResource,
    "ShardingResource is renamed to MeshResource, and will be removed in the near future.",
)

# Public names exported by ``from transformer_engine.jax import *``.
__all__ = [
    "NVTE_FP8_COLLECTION_NAME",
    "fp8_autocast",
    "update_collections",
    "get_delayed_scaling",
    "MeshResource",
    "MajorShardingType",
    "ShardingResource",
    "ShardingType",
    "flax",
    # NOTE(review): ``praxis`` is not imported in this module. A star-import only
    # succeeds if it resolves as a submodule of this package — confirm it exists.
    "praxis",
]