Unverified Commit c6a9e261 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch][Jax] Add warning for missing SOs if both frameworks are installed (#1834)



* Add warning for multi framework case
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarAlp Dener <adener@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarAlp Dener <adener@nvidia.com>
parent d5d78333
......@@ -6,17 +6,54 @@
# pylint: disable=unused-import
import os
from importlib import metadata
import transformer_engine.common
try:
from . import pytorch
except (ImportError, FileNotFoundError):
except ImportError:
pass
except FileNotFoundError as e:
if "Could not find shared object file" in str(e):
if os.getenv("NVTE_FRAMEWORK") is None:
# If we got here, we could import `torch` but could not load the framework extension.
# This can happen when a user wants to work only with `transformer_engine.jax` on a system that
# also has a PyTorch installation. In order to enable that use case, we issue a warning here
# about the missing PyTorch extension in case the user hasn't set NVTE_FRAMEWORK.
import warnings
warnings.warn(
"Detected a PyTorch installation but could not find the shared object file for the "
"Transformer Engine PyTorch extension library. If this is not intentional, please "
"reinstall Transformer Engine with `pip install transformer_engine[pytorch]` or "
"build from source with `NVTE_FRAMEWORK=pytorch`.",
category=RuntimeWarning,
)
elif os.getenv("NVTE_FRAMEWORK") in ("pytorch", "all"):
raise e
try:
from . import jax
except (ImportError, FileNotFoundError):
except ImportError:
pass
except FileNotFoundError as e:
if "Could not find shared object file" in str(e):
if os.getenv("NVTE_FRAMEWORK") is None:
# If we got here, we could import `jax` but could not load the framework extension.
# This can happen when a user wants to work only with `transformer_engine.pytorch` on a system
# that also has a Jax installation. In order to enable that use case, we issue a warning here
# about the missing Jax extension in case the user hasn't set NVTE_FRAMEWORK.
import warnings
warnings.warn(
"Detected a Jax installation but could not find the shared object file for the "
"Transformer Engine Jax extension library. If this is not intentional, please "
"reinstall Transformer Engine with `pip install transformer_engine[jax]` or "
"build from source with `NVTE_FRAMEWORK=jax`.",
category=RuntimeWarning,
)
elif os.getenv("NVTE_FRAMEWORK") in ("jax", "all"):
raise e
__version__ = str(metadata.version("transformer_engine"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment