# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Top level package"""

# pylint: disable=unused-import

import os
from importlib import metadata
import transformer_engine.common


def _handle_missing_framework_extension(
    err: FileNotFoundError, framework: str, display_name: str
) -> None:
    """Decide whether a failed framework-extension load is fatal.

    Called from the ``except FileNotFoundError`` handler that wraps
    ``from . import <framework>``. Errors about a missing shared object
    file are fatal only when ``NVTE_FRAMEWORK`` explicitly requests this
    framework (or ``all``). If ``NVTE_FRAMEWORK`` is unset, a warning is
    emitted instead. Any other ``FileNotFoundError`` is re-raised.

    Args:
        err: The exception raised while importing the framework submodule.
        framework: Submodule / ``NVTE_FRAMEWORK`` name, e.g. ``"pytorch"``.
        display_name: Human-readable framework name used in the warning,
            e.g. ``"PyTorch"``.

    Raises:
        FileNotFoundError: ``err`` itself, when it is not about a missing
            shared object file, or when ``NVTE_FRAMEWORK`` names this
            framework (or ``all``).
    """
    if "Could not find shared object file" not in str(err):
        # Unrelated missing file: surface it rather than silently dropping it.
        raise err

    requested = os.getenv("NVTE_FRAMEWORK")
    if requested:
        # Accept a comma-separated list (presumably mirroring the build's
        # NVTE_FRAMEWORK syntax, e.g. "pytorch,jax"); a single name such as
        # "pytorch" or "all" behaves exactly as before.
        frameworks = {name.strip() for name in requested.split(",")}
        if framework in frameworks or "all" in frameworks:
            raise err
        # The user explicitly built without this framework: stay quiet.
        return

    # If we got here, the framework submodule could be imported far enough to
    # report a missing extension library. This can happen when a user wants to
    # work only with another framework's Transformer Engine bindings on a system
    # that also has this framework installed. To enable that use case, warn
    # (instead of failing) when the user hasn't set NVTE_FRAMEWORK.
    import warnings  # local import keeps `warnings` out of the package namespace

    warnings.warn(
        f"Detected a {display_name} installation but could not find the shared object file for the "
        f"Transformer Engine {display_name} extension library. If this is not intentional, please "
        f"reinstall Transformer Engine with `pip install transformer_engine[{framework}]` or "
        f"build from source with `NVTE_FRAMEWORK={framework}`.",
        category=RuntimeWarning,
        # Attribute the warning to the import site in this module, not this helper.
        stacklevel=2,
    )


try:
    from . import pytorch
except ImportError:
    # PyTorch (or its TE bindings) is not installed; the package works without it.
    pass
except FileNotFoundError as e:
    _handle_missing_framework_extension(e, "pytorch", "PyTorch")

try:
    from . import jax
except ImportError:
    # JAX (or its TE bindings) is not installed; the package works without it.
    pass
except FileNotFoundError as e:
    _handle_missing_framework_extension(e, "jax", "Jax")

__version__ = str(metadata.version("transformer_engine"))