Unverified commit 30e30811, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix multi-framework runtime lib loading (#1825)



* Fix single-framework build when multiple frameworks are available
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Some fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* sug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9627b073
...@@ -11,12 +11,12 @@ import transformer_engine.common ...@@ -11,12 +11,12 @@ import transformer_engine.common
try: try:
from . import pytorch from . import pytorch
except ImportError as e: except (ImportError, FileNotFoundError):
pass pass
try: try:
from . import jax from . import jax
except ImportError as e: except (ImportError, FileNotFoundError):
pass pass
__version__ = str(metadata.version("transformer_engine")) __version__ = str(metadata.version("transformer_engine"))
...@@ -108,9 +108,10 @@ def _get_shared_object_file(library: str) -> Path: ...@@ -108,9 +108,10 @@ def _get_shared_object_file(library: str) -> Path:
# Case 1: Typical user workflow: Both locations are the same, return any result. # Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir: if te_install_dir == site_packages_dir:
assert ( if so_path_in_install_dir is None:
so_path_in_install_dir is not None raise FileNotFoundError(
), f"Could not find shared object file for Transformer Engine {library} lib." f"Could not find shared object file for Transformer Engine {library} lib."
)
return so_path_in_install_dir return so_path_in_install_dir
# Case 2: ERR! Both locations are different but returned a valid result. # Case 2: ERR! Both locations are different but returned a valid result.
...@@ -118,8 +119,7 @@ def _get_shared_object_file(library: str) -> Path: ...@@ -118,8 +119,7 @@ def _get_shared_object_file(library: str) -> Path:
# editable builds. In case developers are executing inside a TE directory via # editable builds. In case developers are executing inside a TE directory via
# an inplace build, and then move to a regular build, the local shared object # an inplace build, and then move to a regular build, the local shared object
# file will be incorrectly picked up without the following logic. # file will be incorrectly picked up without the following logic.
if so_path_in_install_dir is not None and so_path_in_default_dir is not None: assert so_path_in_install_dir is None or so_path_in_default_dir is None, (
raise RuntimeError(
f"Found multiple shared object files: {so_path_in_install_dir} and" f"Found multiple shared object files: {so_path_in_install_dir} and"
f" {so_path_in_default_dir}. Remove local shared objects installed" f" {so_path_in_default_dir}. Remove local shared objects installed"
f" here {so_path_in_install_dir} or change the working directory to" f" here {so_path_in_install_dir} or change the working directory to"
...@@ -134,7 +134,9 @@ def _get_shared_object_file(library: str) -> Path: ...@@ -134,7 +134,9 @@ def _get_shared_object_file(library: str) -> Path:
if so_path_in_default_dir is not None: if so_path_in_default_dir is not None:
return so_path_in_default_dir return so_path_in_default_dir
raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.") raise FileNotFoundError(
f"Could not find shared object file for Transformer Engine {library} lib."
)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
...@@ -198,6 +200,7 @@ def load_framework_extension(framework: str): ...@@ -198,6 +200,7 @@ def load_framework_extension(framework: str):
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _get_sys_extension(): def _get_sys_extension():
system = platform.system() system = platform.system()
if system == "Linux": if system == "Linux":
extension = "so" extension = "so"
elif system == "Darwin": elif system == "Darwin":
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment