Unverified Commit 07291027 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Import framework submodules lazily (#839)



* Import frameworks lazily
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Only load modules lazily after an import error

Pylint doesn't handle lazy loading gracefully.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 87e4d6c3
...@@ -3,15 +3,43 @@ ...@@ -3,15 +3,43 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Top level package""" """Top level package"""
import importlib.util
import sys
from types import ModuleType
from ._version import __version__ from ._version import __version__
from . import common from . import common
def _lazy_import(name: str) -> ModuleType:
"""Construct a module that is imported the first time it is used"""
spec = importlib.util.find_spec(name)
loader = importlib.util.LazyLoader(spec.loader)
spec.loader = loader
module = importlib.util.module_from_spec(spec)
sys.modules[name] = module
loader.exec_module(module)
return module
# Import framework submodules
# Note: Load module lazily if import fails. This way a useful import
# error will be thrown if the user attempts to access the module.
try: try:
from . import pytorch from . import pytorch
except ImportError as e: except ImportError:
pass pytorch = _lazy_import("transformer_engine.pytorch")
try: try:
from . import jax from . import jax
except ImportError as e: except ImportError:
pass jax = _lazy_import("transformer_engine.jax")
try:
from . import paddle
except ImportError:
paddle = _lazy_import("transformer_engine.paddle")
__all__ = [
"__version__",
"common",
"jax",
"paddle",
"pytorch",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment