Unverified commit d4f6d929, authored by Kirthi Shankar Sivamani and committed by GitHub
Browse files

Fix miscellaneous bugs during library loading (#1788)



* Cleanup runtime library loading
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better comments and logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix catching stray builds
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix missing fw case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor grammar
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix duplicate SO for editable installs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better comment for build ext
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve error msg
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7b49cc60
......@@ -130,18 +130,24 @@ def get_build_ext(
super().run()
self.extensions = all_extensions
# Ensure that binaries are not in global package space.
# Ensure that shared object files for source and PyPI installations live
# in separate directories to avoid conflicts during install and runtime.
lib_dir = (
"wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
else ""
)
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"):
self.copy_file(ext, target_dir)
os.remove(ext)
# Ensure that binaries are not in global package space.
# For editable/inplace builds this is not a concern as
# the SOs will be in a local directory anyway.
if not self.inplace:
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"):
self.copy_file(ext, target_dir)
os.remove(ext)
def build_extensions(self):
# For core lib + JAX install, fix build_ext from pybind11.setup_helpers
......
......@@ -144,7 +144,6 @@ if __name__ == "__main__":
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
), "NVTE_RELEASE_BUILD env must be set for metapackage build."
ext_modules = []
cmdclass = {}
package_data = {}
include_package_data = False
setup_requires = []
......@@ -156,7 +155,6 @@ if __name__ == "__main__":
else:
setup_requires, install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
package_data = {"": ["VERSION.txt"]}
include_package_data = True
extras_require = {"test": test_requires}
......
......@@ -11,12 +11,12 @@ import transformer_engine.common
try:
from . import pytorch
except (ImportError, StopIteration) as e:
except ImportError as e:
pass
try:
from . import jax
except (ImportError, StopIteration) as e:
except ImportError as e:
pass
__version__ = str(metadata.version("transformer_engine"))
......@@ -9,28 +9,193 @@ import glob
import sysconfig
import subprocess
import ctypes
import logging
import os
import platform
import importlib
import functools
from pathlib import Path
from importlib.metadata import version, metadata, PackageNotFoundError
def is_package_installed(package):
"""Checks if a pip package is installed."""
return (
subprocess.run(
[sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
).returncode
== 0
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package):
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
# if the python package is installed via pip, and not
# if it's importable in the current directory due to
# the presence of the shared library module.
try:
metadata(package)
except PackageNotFoundError:
return False
return True
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
"""
Find a shared object file of given prefix in the top level TE directory.
Only the following locations are searched to avoid stray SOs and build
artifacts:
1. The given top level directory (editable install).
2. `transformer_engine` named directories (source install).
3. `wheel_lib` named directories (PyPI install).
Returns None if no shared object files are found.
Raises an error if multiple shared object files are found.
"""
# Ensure top level dir exists and has the module. before searching.
if not te_path.exists() or not (te_path / "transformer_engine").exists():
return None
files = []
search_paths = (
te_path,
te_path / "transformer_engine",
te_path / "transformer_engine/wheel_lib",
te_path / "wheel_lib",
)
# Search.
for dirname, _, names in os.walk(te_path):
if Path(dirname) in search_paths:
for name in names:
if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
files.append(Path(dirname, name))
if len(files) == 0:
return None
if len(files) == 1:
return files[0]
raise RuntimeError(f"Multiple files found: {files}")
@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
"""
Return the path of the shared object file for the given TE
library, one of 'core', 'torch', or 'jax'.
Several factors affect finding the correct location of the shared object:
1. System and environment.
2. If the installation is from source or via PyPI.
- Source installed .sos are placed in top level dir
- Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
3. For source installations, is the install editable/inplace?
4. The user directory from where TE is being imported.
"""
# Check provided input and determine the correct prefix for .so.
assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
if library == "core":
so_prefix = "libtransformer_engine"
else:
so_prefix = f"transformer_engine_{library}"
# Check TE install location (will be local if TE is available in current dir for import).
te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
# Check default python package install location in system.
site_packages_dir = Path(sysconfig.get_paths()["purelib"])
so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)
# Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir:
assert (
so_path_in_install_dir is not None
), f"Could not find shared object file for Transformer Engine {library} lib."
return so_path_in_install_dir
# Case 2: ERR! Both locations are different but returned a valid result.
# NOTE: Unlike for source installations, pip does not wipe out artifacts from
# editable builds. In case developers are executing inside a TE directory via
# an inplace build, and then move to a regular build, the local shared object
# file will be incorrectly picked up without the following logic.
if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
raise RuntimeError(
f"Found multiple shared object files: {so_path_in_install_dir} and"
f" {so_path_in_default_dir}. Remove local shared objects installed"
f" here {so_path_in_install_dir} or change the working directory to"
"execute from outside TE."
)
# Case 3: Typical dev workflow: Editable install
if so_path_in_install_dir is not None:
return so_path_in_install_dir
# Case 4: Executing from inside a TE directory without an inplace build available.
if so_path_in_default_dir is not None:
return so_path_in_default_dir
def get_te_path() -> Path:
"""Find Transformer Engine install path using pip"""
return Path(importlib.metadata.distribution("transformer_engine").locate_file("").resolve())
raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
    """
    Load the shared library with the Transformer Engine framework bindings
    and verify package consistency if installed via PyPI.

    `framework` must be 'jax' or 'torch'. The extension module is loaded from
    the shared object located by `_get_shared_object_file` and registered in
    `sys.modules` under the name `transformer_engine_<framework>`.

    Raises an AssertionError for an unsupported framework, or when the
    framework extension pip package is installed but the `transformer-engine`
    or `transformer-engine-cu12` packages are missing or differ in version.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util  # pylint: disable=import-outside-toplevel

    # Supported frameworks.
    assert framework in ("jax", "torch"), f"Unsupported framework {framework}"

    # Name of the framework extension library.
    module_name = f"transformer_engine_{framework}"

    # Name of the pip extra dependency for framework extensions from PyPI.
    # The extra is `pytorch` for torch and `jax` for jax; using the module name
    # here would suggest a non-existent `transformer-engine[transformer_engine_jax]`.
    extra_dep_name = "pytorch" if framework == "torch" else framework
    install_cmd = f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"

    # If the framework extension pip package is installed, it means that TE is installed via
    # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
    # extension are all installed via PyPI and have matching version.
    if _is_pip_package_installed(module_name):
        assert _is_pip_package_installed(
            "transformer_engine"
        ), "Could not find `transformer-engine`."
        assert _is_pip_package_installed(
            "transformer_engine_cu12"
        ), "Could not find `transformer-engine-cu12`."
        assert (
            version(module_name)
            == version("transformer-engine")
            == version("transformer-engine-cu12")
        ), (
            "TransformerEngine package version mismatch. Found"
            f" {module_name} v{version(module_name)}, transformer-engine"
            f" v{version('transformer-engine')}, and transformer-engine-cu12"
            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
            f"{install_cmd}"
        )

    # If the core package is installed via PyPI, log if
    # the framework extension is not found from PyPI.
    # Note: Should we error? This is a rare use case.
    if _is_pip_package_installed("transformer-engine-cu12"):
        if not _is_pip_package_installed(module_name):
            # Lazy %-style arguments only; the template is not mixed with an f-string.
            _logger.info(
                "Could not find package %s. Install transformer-engine using %s",
                module_name,
                install_cmd,
            )

    # After all checks are completed, load the shared object file.
    spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
    solib = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = solib
    spec.loader.exec_module(solib)
@functools.lru_cache(maxsize=None)
def _get_sys_extension():
system = platform.system()
if system == "Linux":
......@@ -45,6 +210,7 @@ def _get_sys_extension():
return extension
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
"""
Attempts to load shared object file installed via pip.
......@@ -82,6 +248,7 @@ def _nvidia_cudart_include_dir():
return str(include_dir) if include_dir.exists() else ""
@functools.lru_cache(maxsize=None)
def _load_cudnn():
"""Load CUDNN shared library."""
......@@ -109,24 +276,7 @@ def _load_cudnn():
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
if not so_path.exists():
so_path = (
get_te_path()
/ "transformer_engine"
/ "wheel_lib"
/ f"libtransformer_engine.{_get_sys_extension()}"
)
if not so_path.exists():
so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
......@@ -158,12 +308,18 @@ def _load_nvrtc():
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_core_library():
    """Open the Transformer Engine core shared object and return its ctypes handle."""
    # The library is opened with RTLD_GLOBAL.
    core_so_path = _get_shared_object_file("core")
    return ctypes.CDLL(core_so_path, mode=ctypes.RTLD_GLOBAL)
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
_TE_LIB_CTYPES = _load_library()
_TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
......
......@@ -20,66 +20,17 @@ All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
# pylint: disable=wrong-import-position,wrong-import-order
# pylint: disable=wrong-import-position
import logging
import importlib
import importlib.util
from importlib.metadata import version
import sys
# This unused import is needed because the top level `transformer_engine/__init__.py`
# file catches an `ImportError` as a guard for cases where the given framework's
# extensions are not available.
import jax
from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension
from transformer_engine.common import load_framework_extension
load_framework_extension("jax")
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_jax"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
logging.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'",
module_name,
)
extension = _get_sys_extension()
try:
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
_load_library()
from . import flax
from . import quantize
......
......@@ -4,22 +4,14 @@
"""Transformer Engine bindings for pyTorch"""
# pylint: disable=wrong-import-position,wrong-import-order
# pylint: disable=wrong-import-position
import logging
import functools
import sys
import importlib
import importlib.util
from importlib.metadata import version
from packaging.version import Version as PkgVersion
import torch
from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
from transformer_engine.common import load_framework_extension
@functools.lru_cache(maxsize=None)
......@@ -28,57 +20,10 @@ def torch_version() -> tuple[int, ...]:
return PkgVersion(str(torch.__version__)).release
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_torch"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
"'pip3 install transformer-engine[pytorch]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[pytorch]==VERSION'",
module_name,
)
extension = _get_sys_extension()
try:
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
_load_library()
load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment