Unverified Commit c1417116 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Minor improvements to runtime error checks during library loading (#1837)



minor build improvements
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 62f5c9ee
......@@ -15,8 +15,14 @@ try:
except ImportError:
pass
except FileNotFoundError as e:
if "Could not find shared object file" in str(e):
if os.getenv("NVTE_FRAMEWORK") is None:
if "Could not find shared object file" not in str(e):
raise e # Unexpected error
else:
if os.getenv("NVTE_FRAMEWORK"):
frameworks = os.getenv("NVTE_FRAMEWORK").split(",")
if "pytorch" in frameworks or "all" in frameworks:
raise e
else:
# If we got here, we could import `torch` but could not load the framework extension.
# This can happen when a user wants to work only with `transformer_engine.jax` on a system that
# also has a PyTorch installation. In order to enable that use case, we issue a warning here
......@@ -30,16 +36,20 @@ except FileNotFoundError as e:
"build from source with `NVTE_FRAMEWORK=pytorch`.",
category=RuntimeWarning,
)
elif os.getenv("NVTE_FRAMEWORK") in ("pytorch", "all"):
raise e
try:
from . import jax
except ImportError:
pass
except FileNotFoundError as e:
if "Could not find shared object file" in str(e):
if os.getenv("NVTE_FRAMEWORK") is None:
if "Could not find shared object file" not in str(e):
raise e # Unexpected error
else:
if os.getenv("NVTE_FRAMEWORK"):
frameworks = os.getenv("NVTE_FRAMEWORK").split(",")
if "jax" in frameworks or "all" in frameworks:
raise e
else:
# If we got here, we could import `jax` but could not load the framework extension.
# This can happen when a user wants to work only with `transformer_engine.pytorch` on a system
# that also has a Jax installation. In order to enable that use case, we issue a warning here
......@@ -53,7 +63,5 @@ except FileNotFoundError as e:
"build from source with `NVTE_FRAMEWORK=jax`.",
category=RuntimeWarning,
)
elif os.getenv("NVTE_FRAMEWORK") in ("jax", "all"):
raise e
__version__ = str(metadata.version("transformer_engine"))
......@@ -57,10 +57,9 @@ def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
files = []
search_paths = (
te_path,
te_path / "transformer_engine",
te_path / "transformer_engine/wheel_lib",
te_path / "wheel_lib",
te_path, # Editable build.
te_path / "transformer_engine", # Regular source build.
te_path / "transformer_engine/wheel_lib", # PyPI.
)
# Search.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment