Unverified commit d4f6d929, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix miscellaneous bugs during library loading (#1788)



* Cleanup runtime library loading
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better comments and logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix catching stray builds
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix missing fw case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor grammar
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix duplicate SO for editable installs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better comment for build ext
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve error msg
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7b49cc60
...@@ -130,18 +130,24 @@ def get_build_ext( ...@@ -130,18 +130,24 @@ def get_build_ext(
super().run() super().run()
self.extensions = all_extensions self.extensions = all_extensions
# Ensure that binaries are not in global package space. # Ensure that shared object files for source and PyPI installations live
# in separate directories to avoid conflicts during install and runtime.
lib_dir = ( lib_dir = (
"wheel_lib" "wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
else "" else ""
) )
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"): # Ensure that binaries are not in global package space.
self.copy_file(ext, target_dir) # For editable/inplace builds this is not a concern as
os.remove(ext) # the SOs will be in a local directory anyway.
if not self.inplace:
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"):
self.copy_file(ext, target_dir)
os.remove(ext)
def build_extensions(self): def build_extensions(self):
# For core lib + JAX install, fix build_ext from pybind11.setup_helpers # For core lib + JAX install, fix build_ext from pybind11.setup_helpers
......
...@@ -144,7 +144,6 @@ if __name__ == "__main__": ...@@ -144,7 +144,6 @@ if __name__ == "__main__":
int(os.getenv("NVTE_RELEASE_BUILD", "0")) int(os.getenv("NVTE_RELEASE_BUILD", "0"))
), "NVTE_RELEASE_BUILD env must be set for metapackage build." ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
ext_modules = [] ext_modules = []
cmdclass = {}
package_data = {} package_data = {}
include_package_data = False include_package_data = False
setup_requires = [] setup_requires = []
...@@ -156,7 +155,6 @@ if __name__ == "__main__": ...@@ -156,7 +155,6 @@ if __name__ == "__main__":
else: else:
setup_requires, install_requires, test_requires = setup_requirements() setup_requires, install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()] ext_modules = [setup_common_extension()]
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
package_data = {"": ["VERSION.txt"]} package_data = {"": ["VERSION.txt"]}
include_package_data = True include_package_data = True
extras_require = {"test": test_requires} extras_require = {"test": test_requires}
......
...@@ -11,12 +11,12 @@ import transformer_engine.common ...@@ -11,12 +11,12 @@ import transformer_engine.common
try: try:
from . import pytorch from . import pytorch
except (ImportError, StopIteration) as e: except ImportError as e:
pass pass
try: try:
from . import jax from . import jax
except (ImportError, StopIteration) as e: except ImportError as e:
pass pass
__version__ = str(metadata.version("transformer_engine")) __version__ = str(metadata.version("transformer_engine"))
...@@ -9,28 +9,193 @@ import glob ...@@ -9,28 +9,193 @@ import glob
import sysconfig import sysconfig
import subprocess import subprocess
import ctypes import ctypes
import logging
import os import os
import platform import platform
import importlib import importlib
import functools import functools
from pathlib import Path from pathlib import Path
from importlib.metadata import version, metadata, PackageNotFoundError
def is_package_installed(package): _logger = logging.getLogger(__name__)
"""Checks if a pip package is installed."""
return (
subprocess.run( @functools.lru_cache(maxsize=None)
[sys.executable, "-m", "pip", "show", package], capture_output=True, check=False def _is_pip_package_installed(package):
).returncode """Check if the given package is installed via pip."""
== 0
# This is needed because we only want to return true
# if the python package is installed via pip, and not
# if it's importable in the current directory due to
# the presence of the shared library module.
try:
metadata(package)
except PackageNotFoundError:
return False
return True
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
    """
    Find a shared object file of given prefix in the top level TE directory.

    Only the following locations are searched to avoid stray SOs and build
    artifacts:
    1. The given top level directory (editable install).
    2. `transformer_engine` named directories (source install).
    3. `wheel_lib` named directories (PyPI install).

    Parameters
    ----------
    te_path : Path
        Directory that is expected to contain the `transformer_engine` package.
    prefix : str
        Leading part of the shared object file name to match.

    Returns
    -------
    Path or None
        Path of the single matching shared object, or None if `te_path` does
        not exist, does not contain `transformer_engine`, or has no match.

    Raises
    ------
    RuntimeError
        If more than one matching shared object file is found.
    """

    # Ensure the top level dir exists and has the module before searching.
    if not te_path.exists() or not (te_path / "transformer_engine").exists():
        return None

    search_paths = (
        te_path,
        te_path / "transformer_engine",
        te_path / "transformer_engine/wheel_lib",
        te_path / "wheel_lib",
    )
    suffix = f".{_get_sys_extension()}"

    # List only the candidate directories. Walking the entire tree below
    # `te_path` (typically all of site-packages) just to look at these four
    # locations is needlessly expensive at import time.
    files = []
    for directory in search_paths:
        if not directory.is_dir():
            continue
        for candidate in directory.iterdir():
            name = candidate.name
            if name.startswith(prefix) and name.endswith(suffix) and candidate.is_file():
                files.append(candidate)

    if not files:
        return None
    if len(files) == 1:
        return files[0]
    raise RuntimeError(f"Multiple files found: {files}")
@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
    """
    Return the path of the shared object file for the given TE
    library, one of 'core', 'torch', or 'jax'.

    Several factors affect finding the correct location of the shared object:
    1. System and environment.
    2. If the installation is from source or via PyPI.
        - Source installed .sos are placed in top level dir
        - Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
    3. For source installations, is the install editable/inplace?
    4. The user directory from where TE is being imported.

    Raises
    ------
    AssertionError
        If `library` is not supported, or if the import location equals the
        default install location and no shared object is found there.
    RuntimeError
        If shared objects are found in both the import location and the
        default install location, or if none is found in either.
    """

    # Check provided input and determine the correct prefix for .so.
    assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
    if library == "core":
        so_prefix = "libtransformer_engine"
    else:
        so_prefix = f"transformer_engine_{library}"

    # Check TE install location (will be local if TE is available in current dir for import).
    te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
    so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)

    # Check default python package install location in system.
    site_packages_dir = Path(sysconfig.get_paths()["purelib"])
    so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)

    # Case 1: Typical user workflow: Both locations are the same, return any result.
    if te_install_dir == site_packages_dir:
        assert (
            so_path_in_install_dir is not None
        ), f"Could not find shared object file for Transformer Engine {library} lib."
        return so_path_in_install_dir

    # Case 2: ERR! Both locations are different but returned a valid result.
    # NOTE: Unlike for source installations, pip does not wipe out artifacts from
    # editable builds. In case developers are executing inside a TE directory via
    # an inplace build, and then move to a regular build, the local shared object
    # file will be incorrectly picked up without the following logic.
    if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
        # Each continuation literal starts with a space so that the adjacent
        # string pieces do not run together in the final message.
        raise RuntimeError(
            f"Found multiple shared object files: {so_path_in_install_dir} and"
            f" {so_path_in_default_dir}. Remove local shared objects installed"
            f" here {so_path_in_install_dir} or change the working directory to"
            " execute from outside TE."
        )

    # Case 3: Typical dev workflow: Editable install
    if so_path_in_install_dir is not None:
        return so_path_in_install_dir

    # Case 4: Executing from inside a TE directory without an inplace build available.
    if so_path_in_default_dir is not None:
        return so_path_in_default_dir

    raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
"""Find Transformer Engine install path using pip"""
return Path(importlib.metadata.distribution("transformer_engine").locate_file("").resolve())
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
    """
    Load shared library with Transformer Engine framework bindings
    and verify correctness if installed via PyPI.

    Parameters
    ----------
    framework : str
        Either "jax" or "torch".

    Raises
    ------
    AssertionError
        If `framework` is unsupported, or if the framework extension is
        installed via pip but the `transformer_engine` / `transformer_engine_cu12`
        packages are missing or their versions do not all match.
    """

    # Supported frameworks.
    assert framework in ("jax", "torch"), f"Unsupported framework {framework}"

    # Name of the framework extension library.
    module_name = f"transformer_engine_{framework}"

    # Name of the pip extra dependency for framework extensions from PyPI.
    # The extras are `jax` and `pytorch`, not the extension module names.
    extra_dep_name = "pytorch" if framework == "torch" else framework
    install_hint = f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"

    # If the framework extension pip package is installed, it means that TE is installed via
    # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
    # extension are all installed via PyPI and have matching version.
    if _is_pip_package_installed(module_name):
        assert _is_pip_package_installed(
            "transformer_engine"
        ), "Could not find `transformer-engine`."
        assert _is_pip_package_installed(
            "transformer_engine_cu12"
        ), "Could not find `transformer-engine-cu12`."
        assert (
            version(module_name)
            == version("transformer-engine")
            == version("transformer-engine-cu12")
        ), (
            "TransformerEngine package version mismatch. Found"
            f" {module_name} v{version(module_name)}, transformer-engine"
            f" v{version('transformer-engine')}, and transformer-engine-cu12"
            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
            f"{install_hint}"
        )

    # If the core package is installed via PyPI, log if
    # the framework extension is not found from PyPI.
    # Note: Should we error? This is a rare use case.
    if _is_pip_package_installed("transformer-engine-cu12"):
        if not _is_pip_package_installed(module_name):
            # Pass both values as lazy logging arguments instead of mixing an
            # f-string into the %-style format string.
            _logger.info(
                "Could not find package %s. Install transformer-engine using %s",
                module_name,
                install_hint,
            )

    # After all checks are completed, load the shared object file.
    spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
    solib = importlib.util.module_from_spec(spec)
    # Register the module before executing it so it is importable by name.
    sys.modules[module_name] = solib
    spec.loader.exec_module(solib)
@functools.lru_cache(maxsize=None)
def _get_sys_extension(): def _get_sys_extension():
system = platform.system() system = platform.system()
if system == "Linux": if system == "Linux":
...@@ -45,6 +210,7 @@ def _get_sys_extension(): ...@@ -45,6 +210,7 @@ def _get_sys_extension():
return extension return extension
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str): def _load_nvidia_cuda_library(lib_name: str):
""" """
Attempts to load shared object file installed via pip. Attempts to load shared object file installed via pip.
...@@ -82,6 +248,7 @@ def _nvidia_cudart_include_dir(): ...@@ -82,6 +248,7 @@ def _nvidia_cudart_include_dir():
return str(include_dir) if include_dir.exists() else "" return str(include_dir) if include_dir.exists() else ""
@functools.lru_cache(maxsize=None)
def _load_cudnn(): def _load_cudnn():
"""Load CUDNN shared library.""" """Load CUDNN shared library."""
...@@ -109,24 +276,7 @@ def _load_cudnn(): ...@@ -109,24 +276,7 @@ def _load_cudnn():
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
def _load_library(): @functools.lru_cache(maxsize=None)
"""Load shared library with Transformer Engine C extensions"""
so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
if not so_path.exists():
so_path = (
get_te_path()
/ "transformer_engine"
/ "wheel_lib"
/ f"libtransformer_engine.{_get_sys_extension()}"
)
if not so_path.exists():
so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
def _load_nvrtc(): def _load_nvrtc():
"""Load NVRTC shared library.""" """Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
...@@ -158,12 +308,18 @@ def _load_nvrtc(): ...@@ -158,12 +308,18 @@ def _load_nvrtc():
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_core_library():
    """Open the Transformer Engine core shared library and return its ctypes handle."""
    # Resolve the location of the core shared object first, then open it
    # with the RTLD_GLOBAL flag.
    core_so_path = _get_shared_object_file("core")
    core_lib = ctypes.CDLL(core_so_path, mode=ctypes.RTLD_GLOBAL)
    return core_lib
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn() _CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc() _NVRTC_LIB_CTYPES = _load_nvrtc()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
_TE_LIB_CTYPES = _load_library() _TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels. # Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir(): if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
......
...@@ -20,66 +20,17 @@ All operations are designed to work seamlessly with JAX's functional programming ...@@ -20,66 +20,17 @@ All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation. model and support automatic differentiation.
""" """
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position
import logging # This unused import is needed because the top level `transformer_engine/__init__.py`
import importlib # file catches an `ImportError` as a guard for cases where the given framework's
import importlib.util # extensions are not available.
from importlib.metadata import version import jax
import sys
from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import load_framework_extension
from transformer_engine.common import _get_sys_extension
load_framework_extension("jax")
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_jax"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
logging.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'",
module_name,
)
extension = _get_sys_extension()
try:
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
_load_library()
from . import flax from . import flax
from . import quantize from . import quantize
......
...@@ -4,22 +4,14 @@ ...@@ -4,22 +4,14 @@
"""Transformer Engine bindings for pyTorch""" """Transformer Engine bindings for pyTorch"""
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position
import logging
import functools import functools
import sys
import importlib
import importlib.util
from importlib.metadata import version
from packaging.version import Version as PkgVersion from packaging.version import Version as PkgVersion
import torch import torch
from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import load_framework_extension
from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
...@@ -28,57 +20,10 @@ def torch_version() -> tuple[int, ...]: ...@@ -28,57 +20,10 @@ def torch_version() -> tuple[int, ...]:
return PkgVersion(str(torch.__version__)).release return PkgVersion(str(torch.__version__)).release
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_torch"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
"'pip3 install transformer-engine[pytorch]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[pytorch]==VERSION'",
module_name,
)
extension = _get_sys_extension()
try:
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
_load_library() load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP from transformer_engine.pytorch.module import LayerNormMLP
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment