# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Top level package"""
6
7
8
9
import importlib.util
import sys
from types import ModuleType

10
from ._version import __version__
Przemek Tredak's avatar
Przemek Tredak committed
11
from . import common
12

13
14
15
16
17
18
19
20
21
22
23
24
25
def _lazy_import(name: str) -> ModuleType:
    """Construct a module that is imported the first time it is used

    The module object is registered in ``sys.modules`` under ``name``
    immediately, while execution of the module is deferred by wrapping
    its loader in ``importlib.util.LazyLoader``.

    Parameters
    ----------
    name: str
        Absolute name of the module to load lazily.

    Returns
    -------
    ModuleType
        The lazily loaded module object.

    Raises
    ------
    ModuleNotFoundError
        If no module named ``name`` can be located.
    ImportError
        If the located module spec has no loader to defer.
    """
    spec = importlib.util.find_spec(name)
    # find_spec returns None (rather than raising) for an unknown
    # top-level module; fail with a clear message instead of an
    # AttributeError on ``spec.loader``.
    if spec is None:
        raise ModuleNotFoundError(f"No module named {name!r}", name=name)
    if spec.loader is None:
        raise ImportError(
            f"Cannot lazily import {name!r}: module spec has no loader",
            name=name,
        )
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    # Register before exec_module so that later imports of ``name``
    # resolve to this lazy module object.
    sys.modules[name] = module
    loader.exec_module(module)
    return module

# Import framework submodules
# Note: Load module lazily if import fails. This way a useful import
# error will be thrown if the user attempts to access the module,
# instead of the top-level package import failing.
try:
    from . import pytorch
except ImportError:
    pytorch = _lazy_import("transformer_engine.pytorch")
try:
    from . import jax
except ImportError:
    jax = _lazy_import("transformer_engine.jax")
try:
    from . import paddle
except ImportError:
    paddle = _lazy_import("transformer_engine.paddle")

# Names exported by a star import of this package
__all__ = [
    "__version__",
    "common",
    "jax",
    "paddle",
    "pytorch",
]