Commit 5101e6bc authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Add ahead of time cython compilation in setup.py (#622)

* [Enhancement] Add Cython support and compiler detection in setup.py

- Introduced a new `CythonExtension` class for building Cython-based extensions, enhancing the build process for Cython projects.
- Implemented functions to detect the Cython compiler and C++ compiler, improving compatibility and user experience.
- Updated the build process to handle Cython extensions alongside CMake extensions, ensuring a seamless integration for users.
- Added caching mechanisms for Cython compilation to optimize build times and reduce unnecessary recompilation.

* [Enhancement] Add Cython dependency and enable CMake extension building

- Added Cython as a required dependency in `pyproject.toml` to support Cython-based extensions.
- Updated `setup.py` to enable building CMake extensions, improving the build process for projects utilizing both Cython and CMake.
- Modified the Cython compiler detection logic to streamline installation instructions for users.
parent 0fd3a3e8
[build-system] [build-system]
requires = [ requires = [
"cmake>=3.26", "cmake>=3.26",
"cython",
"packaging", "packaging",
"setuptools>=61", "setuptools>=61",
"wheel", "wheel",
......
...@@ -6,12 +6,17 @@ from setuptools.command.build_py import build_py ...@@ -6,12 +6,17 @@ from setuptools.command.build_py import build_py
from setuptools.command.sdist import sdist from setuptools.command.sdist import sdist
from setuptools.command.develop import develop from setuptools.command.develop import develop
import distutils.dir_util import distutils.dir_util
from typing import List from typing import List, Optional
import re import re
import tarfile import tarfile
from io import BytesIO from io import BytesIO
from pathlib import Path
import os import os
import sys import sys
import site
import hashlib
import sysconfig
import functools
import urllib.request import urllib.request
from distutils.version import LooseVersion from distutils.version import LooseVersion
import platform import platform
...@@ -19,6 +24,7 @@ import multiprocessing ...@@ -19,6 +24,7 @@ import multiprocessing
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
import importlib import importlib
import logging import logging
import fcntl
# Configure logging with basic settings # Configure logging with basic settings
logging.basicConfig( logging.basicConfig(
...@@ -73,6 +79,11 @@ if not (CUDA_HOME or ROCM_HOME): ...@@ -73,6 +79,11 @@ if not (CUDA_HOME or ROCM_HOME):
assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)." assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)."
def _is_linux_like():
return (sys.platform == "darwin" or sys.platform.startswith("linux") or
sys.platform.startswith("freebsd"))
def get_path(*filepath) -> str: def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath) return os.path.join(ROOT_DIR, *filepath)
...@@ -167,6 +178,67 @@ def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=F ...@@ -167,6 +178,67 @@ def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=F
return version return version
@functools.lru_cache(maxsize=None)
def get_cplus_compiler():
"""Return the path to the default C/C++ compiler.
Returns
-------
out: Optional[str]
The path to the default C/C++ compiler, or None if none was found.
"""
if not _is_linux_like():
return None
env_cxx = os.environ.get("CXX") or os.environ.get("CC")
if env_cxx:
return env_cxx
cc_names = ["g++", "clang++", "c++"]
dirs_in_path = os.get_exec_path()
for cc in cc_names:
for d in dirs_in_path:
cc_path = os.path.join(d, cc)
if os.path.isfile(cc_path) and os.access(cc_path, os.X_OK):
return cc_path
return None
def get_cython_compiler() -> Optional[str]:
"""Return the path to the Cython compiler.
Returns
-------
out: Optional[str]
The path to the Cython compiler, or None if none was found.
"""
cython_names = ["cython", "cython3"]
# Check system PATH
dirs_in_path = list(os.get_exec_path())
# Add user site-packages bin directory
user_base = site.getuserbase()
if user_base:
user_bin = os.path.join(user_base, "bin")
if os.path.exists(user_bin):
dirs_in_path = [user_bin] + dirs_in_path
# If in a virtual environment, add its bin directory
if sys.prefix != sys.base_prefix:
venv_bin = os.path.join(sys.prefix, "bin")
if os.path.exists(venv_bin):
dirs_in_path = [venv_bin] + dirs_in_path
for cython_name in cython_names:
for d in dirs_in_path:
cython_path = os.path.join(d, cython_name)
if os.path.isfile(cython_path) and os.access(cython_path, os.X_OK):
return cython_path
return None
def get_system_info(): def get_system_info():
system = platform.system().lower() system = platform.system().lower()
if system == "linux": if system == "linux":
...@@ -581,7 +653,17 @@ class CMakeExtension(Extension): ...@@ -581,7 +653,17 @@ class CMakeExtension(Extension):
self.sourcedir = os.path.abspath(sourcedir) self.sourcedir = os.path.abspath(sourcedir)
class CMakeBuild(build_ext): class CythonExtension(Extension):
"""
A specialized setuptools Extension class for building a Cython project.
"""
def __init__(self, name, sourcedir=""):
super().__init__(name=name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
class TilelangExtensionBuild(build_ext):
""" """
Custom build_ext command for CMake-based projects. Custom build_ext command for CMake-based projects.
...@@ -603,7 +685,12 @@ class CMakeBuild(build_ext): ...@@ -603,7 +685,12 @@ class CMakeBuild(build_ext):
# Build each extension (of type CMakeExtension) using our custom method. # Build each extension (of type CMakeExtension) using our custom method.
for ext in self.extensions: for ext in self.extensions:
self.build_cmake(ext) if isinstance(ext, CythonExtension):
self.build_cython(ext)
elif isinstance(ext, CMakeExtension):
self.build_cmake(ext)
else:
raise ValueError(f"Unsupported extension type: {type(ext)}")
# To make it works with editable install, # To make it works with editable install,
# we need to copy the lib*.so files to the tilelang/lib directory # we need to copy the lib*.so files to the tilelang/lib directory
...@@ -618,6 +705,96 @@ class CMakeBuild(build_ext): ...@@ -618,6 +705,96 @@ class CMakeBuild(build_ext):
# remove the original file # remove the original file
os.remove(file) os.remove(file)
def build_cython(self, ext):
"""
Build a single Cython-based extension.
:param ext: The extension (an instance of CythonExtension).
"""
cython_compiler = get_cython_compiler()
if not cython_compiler:
logger.info("Cython compiler not found, install it first")
subprocess.check_call(["pip", "install", "cython"])
cython_compiler = get_cython_compiler()
if not cython_compiler:
raise Exception("Cython is not installed, please install it first.")
logger.info(f"Using Cython compiler: {cython_compiler}")
cython_warpper_dir = os.path.join(ext.sourcedir, "tilelang", "jit", "adapter", "cython")
cython_wrapper_path = os.path.join(cython_warpper_dir, "cython_wrapper.pyx")
py_version = f"py{sys.version_info.major}{sys.version_info.minor}"
cache_dir = Path(cython_warpper_dir) / ".cycache" / py_version
os.makedirs(cache_dir, exist_ok=True)
with open(cython_wrapper_path, "r") as f:
cython_wrapper_code = f.read()
source_path = cache_dir / "cython_wrapper.cpp"
library_path = cache_dir / "cython_wrapper.so"
md5_path = cache_dir / "md5.txt"
code_hash = hashlib.sha256(cython_wrapper_code.encode()).hexdigest()
cache_path = cache_dir / f"{code_hash}.so"
lock_file = cache_path.with_suffix('.lock')
# Check if cached version exists and is valid
need_compile = True
if md5_path.exists() and library_path.exists():
with open(md5_path, "r") as f:
cached_hash = f.read().strip()
if cached_hash == code_hash:
logger.info("Cython jit adapter is up to date, no need to compile...")
need_compile = False
else:
logger.info("Cython jit adapter is out of date, need to recompile...")
else:
logger.info("No cached version found for cython jit adapter, need to compile...")
if need_compile:
logger.info("Waiting for lock to compile cython jit adapter...")
with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try:
# After acquiring the lock, check again if the file has been compiled by another process
if md5_path.exists() and library_path.exists():
with open(md5_path, "r") as f:
cached_hash = f.read().strip()
if cached_hash == code_hash:
logger.info(
"Another process has already compiled the file, using it..."
)
need_compile = False
if need_compile:
logger.info("Compiling cython jit adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so"
with open(md5_path, "w") as f:
f.write(code_hash)
# compile the cython_wrapper.pyx file into .cpp
cython = get_cython_compiler()
if cython is None:
raise Exception("Cython is not installed, please install it first.")
os.system(f"{cython} {cython_wrapper_path} --cplus -o {source_path}")
python_include_path = sysconfig.get_path("include")
cc = get_cplus_compiler()
command = f"{cc} -shared -pthread -fPIC -fwrapv -O2 -Wall -fno-strict-aliasing -I{python_include_path} {source_path} -o {temp_path}"
os.system(command)
# rename the temp file to the library file
temp_path.rename(library_path)
except Exception as e:
if 'temp_path' in locals() and temp_path.exists():
temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e
finally:
if lock_file.exists():
lock_file.unlink()
# add the .so file to the sys.path
cache_dir_str = str(cache_dir)
if cache_dir_str not in sys.path:
sys.path.append(cache_dir_str)
def build_cmake(self, ext): def build_cmake(self, ext):
""" """
Build a single CMake-based extension. Build a single CMake-based extension.
...@@ -696,11 +873,14 @@ setup( ...@@ -696,11 +873,14 @@ setup(
install_requires=get_requirements(), install_requires=get_requirements(),
package_data=package_data, package_data=package_data,
include_package_data=False, include_package_data=False,
ext_modules=[CMakeExtension("TileLangCXX", sourcedir=".")], ext_modules=[
CMakeExtension("TileLangCXX", sourcedir="."),
CythonExtension("TileLangCython", sourcedir="."),
],
cmdclass={ cmdclass={
"build_py": TileLangBuilPydCommand, "build_py": TileLangBuilPydCommand,
"sdist": TileLangSdistCommand, "sdist": TileLangSdistCommand,
"build_ext": CMakeBuild, "build_ext": TilelangExtensionBuild,
"develop": TileLangDevelopCommand, "develop": TileLangDevelopCommand,
}, },
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment