Unverified Commit 05eb6deb authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Revert "Import framework submodules lazily (#839)" (#851)

This reverts commit 07291027

.
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 50e7a3da
...@@ -3,43 +3,15 @@ ...@@ -3,43 +3,15 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Top level package""" """Top level package"""
import importlib.util
import sys
from types import ModuleType
from ._version import __version__ from ._version import __version__
from . import common from . import common
def _lazy_import(name: str) -> ModuleType:
"""Construct a module that is imported the first time it is used"""
spec = importlib.util.find_spec(name)
loader = importlib.util.LazyLoader(spec.loader)
spec.loader = loader
module = importlib.util.module_from_spec(spec)
sys.modules[name] = module
loader.exec_module(module)
return module
# Import framework submodules
# Note: Load module lazily if import fails. This way a useful import
# error will be thrown if the user attempts to access the module.
try: try:
from . import pytorch from . import pytorch
except ImportError: except ImportError as e:
pytorch = _lazy_import("transformer_engine.pytorch") pass
try: try:
from . import jax from . import jax
except ImportError: except ImportError as e:
jax = _lazy_import("transformer_engine.jax") pass
try:
from . import paddle
except ImportError:
paddle = _lazy_import("transformer_engine.paddle")
__all__ = [
"__version__",
"common",
"jax",
"paddle",
"pytorch",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment