Unverified Commit a3e6ed80 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix installation from PyPI wheels after a source install (#1526)



* Fix wheel install after src install
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix JAX imports
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* switch order of dirs for finding so
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use existing dir src build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 547d8dd8
...@@ -94,7 +94,7 @@ class CMakeExtension(setuptools.Extension): ...@@ -94,7 +94,7 @@ class CMakeExtension(setuptools.Extension):
print(f"Time for build_ext: {total_time:.2f} seconds") print(f"Time for build_ext: {total_time:.2f} seconds")
def get_build_ext(extension_cls: Type[setuptools.Extension]): def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False):
class _CMakeBuildExtension(extension_cls): class _CMakeBuildExtension(extension_cls):
"""Setuptools command with support for CMake extension modules""" """Setuptools command with support for CMake extension modules"""
...@@ -130,7 +130,12 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]): ...@@ -130,7 +130,12 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
self.extensions = all_extensions self.extensions = all_extensions
# Ensure that binaries are not in global package space. # Ensure that binaries are not in global package space.
target_dir = install_dir / "transformer_engine" lib_dir = (
"wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib
else ""
)
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True) target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"): for ext in Path(self.build_lib).glob("*.so"):
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Shared functions for the encoder tests""" """Shared functions for the encoder tests"""
from functools import lru_cache from functools import lru_cache
from transformer_engine.transformer_engine_jax import get_device_compute_capability from transformer_engine_jax import get_device_compute_capability
@lru_cache @lru_cache
......
...@@ -4,8 +4,6 @@ extension-pkg-whitelist=flash_attn_2_cuda, ...@@ -4,8 +4,6 @@ extension-pkg-whitelist=flash_attn_2_cuda,
transformer_engine_torch, transformer_engine_torch,
transformer_engine_jax transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals, disable=too-many-locals,
too-few-public-methods, too-few-public-methods,
too-many-public-methods, too-many-public-methods,
......
...@@ -6,7 +6,9 @@ import os ...@@ -6,7 +6,9 @@ import os
import jax import jax
import pytest import pytest
from transformer_engine.transformer_engine_jax import get_device_compute_capability
import transformer_engine.jax
from transformer_engine_jax import get_device_compute_capability
@pytest.fixture(autouse=True, scope="function") @pytest.fixture(autouse=True, scope="function")
......
...@@ -38,7 +38,7 @@ from transformer_engine.jax.attention import ( ...@@ -38,7 +38,7 @@ from transformer_engine.jax.attention import (
ReorderStrategy, ReorderStrategy,
) )
from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import ( from transformer_engine_jax import (
NVTE_Fused_Attn_Backend, NVTE_Fused_Attn_Backend,
get_cudnn_version, get_cudnn_version,
) )
......
...@@ -83,6 +83,13 @@ def _load_library(): ...@@ -83,6 +83,13 @@ def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}" so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
if not so_path.exists():
so_path = (
get_te_path()
/ "transformer_engine"
/ "wheel_lib"
/ f"libtransformer_engine.{_get_sys_extension()}"
)
if not so_path.exists(): if not so_path.exists():
so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}" so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}" assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
......
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position,wrong-import-order
import sys
import logging import logging
import importlib
import importlib.util
import ctypes import ctypes
from importlib.metadata import version from importlib.metadata import version
...@@ -49,13 +52,20 @@ def _load_library(): ...@@ -49,13 +52,20 @@ def _load_library():
so_dir = get_te_path() / "transformer_engine" so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration: except StopIteration:
so_dir = get_te_path() try:
so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
_TE_JAX_LIB_CTYPES = _load_library() _load_library()
from . import flax from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME from .fp8 import NVTE_FP8_COLLECTION_NAME
......
...@@ -13,11 +13,11 @@ import jax ...@@ -13,11 +13,11 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.linen import make_attention_mask from flax.linen import make_attention_mask
from transformer_engine.transformer_engine_jax import NVTE_Bias_Type from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_QKV_Format from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine.transformer_engine_jax import nvte_get_qkv_format from transformer_engine_jax import nvte_get_qkv_format
from . import cpp_extensions as tex from . import cpp_extensions as tex
......
...@@ -13,8 +13,8 @@ from jax.interpreters.mlir import ir ...@@ -13,8 +13,8 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi from jax import ffi
from transformer_engine import transformer_engine_jax import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type from transformer_engine_jax import NVTE_Activation_Type
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
......
...@@ -17,6 +17,9 @@ from jax.interpreters.mlir import ir ...@@ -17,6 +17,9 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi from jax import ffi
import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
...@@ -26,9 +29,6 @@ from transformer_engine.jax.attention import ( ...@@ -26,9 +29,6 @@ from transformer_engine.jax.attention import (
SequenceDescriptor, SequenceDescriptor,
) )
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import ( from .misc import (
......
...@@ -7,7 +7,7 @@ from enum import IntEnum ...@@ -7,7 +7,7 @@ from enum import IntEnum
import jax import jax
from jax.interpreters import mlir from jax.interpreters import mlir
from transformer_engine import transformer_engine_jax import transformer_engine_jax
from .misc import is_ffi_enabled from .misc import is_ffi_enabled
......
...@@ -15,8 +15,8 @@ import jax.numpy as jnp ...@@ -15,8 +15,8 @@ import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type from jax.interpreters.mlir import dtype_to_ir_type
from transformer_engine.transformer_engine_jax import DType as TEDType from transformer_engine_jax import DType as TEDType
from transformer_engine import transformer_engine_jax import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec from ..sharding import get_padded_spec as te_get_padded_spec
......
...@@ -15,7 +15,7 @@ from jax.interpreters.mlir import ir ...@@ -15,7 +15,7 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi from jax import ffi
from transformer_engine import transformer_engine_jax import transformer_engine_jax
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
......
...@@ -11,8 +11,8 @@ from jax.interpreters.mlir import ir ...@@ -11,8 +11,8 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi from jax import ffi
from transformer_engine import transformer_engine_jax import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
......
...@@ -14,7 +14,7 @@ from jax.interpreters.mlir import ir ...@@ -14,7 +14,7 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi from jax import ffi
from transformer_engine import transformer_engine_jax import transformer_engine_jax
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
......
...@@ -13,8 +13,8 @@ from jax.interpreters.mlir import ir ...@@ -13,8 +13,8 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi from jax import ffi
from transformer_engine import transformer_engine_jax import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
......
...@@ -14,9 +14,9 @@ import jax.numpy as jnp ...@@ -14,9 +14,9 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops from flax.linen import fp8_ops
from transformer_engine.transformer_engine_jax import DType from transformer_engine_jax import DType
from transformer_engine.transformer_engine_jax import get_cublasLt_version from transformer_engine_jax import get_cublasLt_version
from transformer_engine.transformer_engine_jax import ( from transformer_engine_jax import (
get_cuda_version, get_cuda_version,
get_device_compute_capability, get_device_compute_capability,
) )
......
...@@ -37,7 +37,7 @@ install_and_import("pybind11") ...@@ -37,7 +37,7 @@ install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
os.environ["NVTE_PROJECT_BUILDING"] = "1" os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension) CMakeBuildExtension = get_build_ext(BuildExtension, True)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -62,8 +62,12 @@ def _load_library(): ...@@ -62,8 +62,12 @@ def _load_library():
so_dir = get_te_path() / "transformer_engine" so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration: except StopIteration:
so_dir = get_te_path() try:
so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path) spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec) solib = importlib.util.module_from_spec(spec)
......
...@@ -35,7 +35,7 @@ from build_tools.pytorch import setup_pytorch_extension ...@@ -35,7 +35,7 @@ from build_tools.pytorch import setup_pytorch_extension
os.environ["NVTE_PROJECT_BUILDING"] = "1" os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension) CMakeBuildExtension = get_build_ext(BuildExtension, True)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment