Unverified Commit a3e6ed80 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix installation from PyPI wheels after a source install (#1526)



* Fix wheel install after src install
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix JAX imports
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* switch order of dirs for finding so
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use existing dir src build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 547d8dd8
......@@ -94,7 +94,7 @@ class CMakeExtension(setuptools.Extension):
print(f"Time for build_ext: {total_time:.2f} seconds")
def get_build_ext(extension_cls: Type[setuptools.Extension]):
def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False):
class _CMakeBuildExtension(extension_cls):
"""Setuptools command with support for CMake extension modules"""
......@@ -130,7 +130,12 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
self.extensions = all_extensions
# Ensure that binaries are not in global package space.
target_dir = install_dir / "transformer_engine"
lib_dir = (
"wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib
else ""
)
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"):
......
......@@ -4,7 +4,7 @@
"""Shared functions for the encoder tests"""
from functools import lru_cache
from transformer_engine.transformer_engine_jax import get_device_compute_capability
from transformer_engine_jax import get_device_compute_capability
@lru_cache
......
......@@ -4,8 +4,6 @@ extension-pkg-whitelist=flash_attn_2_cuda,
transformer_engine_torch,
transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals,
too-few-public-methods,
too-many-public-methods,
......
......@@ -6,7 +6,9 @@ import os
import jax
import pytest
from transformer_engine.transformer_engine_jax import get_device_compute_capability
import transformer_engine.jax
from transformer_engine_jax import get_device_compute_capability
@pytest.fixture(autouse=True, scope="function")
......
......@@ -38,7 +38,7 @@ from transformer_engine.jax.attention import (
ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
from transformer_engine_jax import (
NVTE_Fused_Attn_Backend,
get_cudnn_version,
)
......
......@@ -83,6 +83,13 @@ def _load_library():
"""Load shared library with Transformer Engine C extensions"""
so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
if not so_path.exists():
so_path = (
get_te_path()
/ "transformer_engine"
/ "wheel_lib"
/ f"libtransformer_engine.{_get_sys_extension()}"
)
if not so_path.exists():
so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
......
......@@ -5,7 +5,10 @@
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import logging
import importlib
import importlib.util
import ctypes
from importlib.metadata import version
......@@ -49,13 +52,20 @@ def _load_library():
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
_TE_JAX_LIB_CTYPES = _load_library()
_load_library()
from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
......
......@@ -13,11 +13,11 @@ import jax
import jax.numpy as jnp
from flax.linen import make_attention_mask
from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_QKV_Format
from transformer_engine.transformer_engine_jax import nvte_get_qkv_format
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format
from . import cpp_extensions as tex
......
......@@ -13,8 +13,8 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type
import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
......
......@@ -17,6 +17,9 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
......@@ -26,9 +29,6 @@ from transformer_engine.jax.attention import (
SequenceDescriptor,
)
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
......
......@@ -7,7 +7,7 @@ from enum import IntEnum
import jax
from jax.interpreters import mlir
from transformer_engine import transformer_engine_jax
import transformer_engine_jax
from .misc import is_ffi_enabled
......
......@@ -15,8 +15,8 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type
from transformer_engine.transformer_engine_jax import DType as TEDType
from transformer_engine import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
......
......@@ -15,7 +15,7 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
from transformer_engine import transformer_engine_jax
import transformer_engine_jax
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
......
......@@ -11,8 +11,8 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
......
......@@ -14,7 +14,7 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
from transformer_engine import transformer_engine_jax
import transformer_engine_jax
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
......
......@@ -13,8 +13,8 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
......
......@@ -14,9 +14,9 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops
from transformer_engine.transformer_engine_jax import DType
from transformer_engine.transformer_engine_jax import get_cublasLt_version
from transformer_engine.transformer_engine_jax import (
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
......
......@@ -37,7 +37,7 @@ install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension
os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension)
CMakeBuildExtension = get_build_ext(BuildExtension, True)
if __name__ == "__main__":
......
......@@ -62,8 +62,12 @@ def _load_library():
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
......
......@@ -35,7 +35,7 @@ from build_tools.pytorch import setup_pytorch_extension
os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension)
CMakeBuildExtension = get_build_ext(BuildExtension, True)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment