Unverified Commit 07a22cbb authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

use env variable to control the build conf on the CPU build node (#3080)

parent 3d0bfa3e
...@@ -11,6 +11,9 @@ docker run --rm \ ...@@ -11,6 +11,9 @@ docker run --rm \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
export CUDA_VERSION=${CUDA_VERSION} && \ export CUDA_VERSION=${CUDA_VERSION} && \
export SGL_KERNEL_ENABLE_BF16=1 && \
export SGL_KERNEL_ENABLE_FP8=1 && \
export SGL_KERNEL_ENABLE_SM90A=1 && \
mkdir -p /usr/lib/x86_64-linux-gnu/ && \ mkdir -p /usr/lib/x86_64-linux-gnu/ && \
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
cd /sgl-kernel && \ cd /sgl-kernel && \
......
import os
from pathlib import Path from pathlib import Path
import torch import torch
from setuptools import find_packages, setup from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from version import __version__
root = Path(__file__).parent.resolve() root = Path(__file__).parent.resolve()
def update_wheel_platform_tag(): def _update_wheel_platform_tag():
wheel_dir = Path("dist") wheel_dir = Path("dist")
if wheel_dir.exists() and wheel_dir.is_dir(): if wheel_dir.exists() and wheel_dir.is_dir():
old_wheel = next(wheel_dir.glob("*.whl")) old_wheel = next(wheel_dir.glob("*.whl"))
...@@ -18,21 +18,25 @@ def update_wheel_platform_tag(): ...@@ -18,21 +18,25 @@ def update_wheel_platform_tag():
old_wheel.rename(new_wheel) old_wheel.rename(new_wheel)
def get_cuda_version(): def _get_cuda_version():
if torch.version.cuda: if torch.version.cuda:
return tuple(map(int, torch.version.cuda.split("."))) return tuple(map(int, torch.version.cuda.split(".")))
return (0, 0) return (0, 0)
def get_device_sm(): def _get_device_sm():
if torch.cuda.is_available(): if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
return major * 10 + minor return major * 10 + minor
return 0 return 0
def _get_version():
    """Read the package version from pyproject.toml next to this setup.py.

    Scans for the first line whose key is ``version`` and returns the quoted
    value (e.g. ``0.0.2``).

    Raises:
        FileNotFoundError: if pyproject.toml is missing.
        RuntimeError: if no ``version`` line is found.  Previously this case
            silently returned ``None``, which only surfaced later as a
            confusing failure inside ``setup()``.
    """
    pyproject = root / "pyproject.toml"
    with open(pyproject) as f:
        for line in f:
            # strip() tolerates indented keys; maxsplit=1 keeps values that
            # themselves contain '=' intact (plain split("=") truncated them).
            if line.strip().startswith("version"):
                return line.split("=", 1)[1].strip().strip('"')
    raise RuntimeError(f"could not find a version entry in {pyproject}")
# Vendored third-party sources (git submodules) used for include paths.
cutlass = root.joinpath("3rdparty", "cutlass")
flashinfer = root.joinpath("3rdparty", "flashinfer")
...@@ -58,19 +62,39 @@ nvcc_flags = [ ...@@ -58,19 +62,39 @@ nvcc_flags = [
"-DFLASHINFER_ENABLE_F16", "-DFLASHINFER_ENABLE_F16",
] ]
# Feature switches for CPU-only build nodes (e.g. CI): when no GPU is
# visible, these env vars decide which kernels get compiled in.
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"

cuda_version = _get_cuda_version()
sm_version = _get_device_sm()

# FP8 defines are identical in both branches below; keep a single copy.
_FP8_NVCC_FLAGS = [
    "-DFLASHINFER_ENABLE_FP8",
    "-DFLASHINFER_ENABLE_FP8_E4M3",
    "-DFLASHINFER_ENABLE_FP8_E5M2",
]

if torch.cuda.is_available():
    # A GPU is visible: derive flags from the detected device and toolkit.
    if cuda_version >= (12, 0) and sm_version >= 90:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if sm_version >= 90:
        nvcc_flags.extend(_FP8_NVCC_FLAGS)
    if sm_version >= 80:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
else:
    # Compilation environment without a GPU: trust the env-var switches.
    if enable_sm90a:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if enable_fp8:
        nvcc_flags.extend(_FP8_NVCC_FLAGS)
    if enable_bf16:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
for flag in [ for flag in [
"-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_OPERATORS__",
...@@ -82,6 +106,7 @@ for flag in [ ...@@ -82,6 +106,7 @@ for flag in [
torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag) torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
except ValueError: except ValueError:
pass pass
# Host (C++) compiler flags.
cxx_flags = ["-O3"]

# Libraries the extension links against; `cuda` resolves against the driver
# stub that the build script symlinks into /usr/lib/x86_64-linux-gnu.
libraries = ["c10", "torch", "torch_python", "cuda"]

# rpath lets the installed wheel locate torch's bundled shared libs at
# runtime; -L points the linker at the stub directory at build time.
extra_link_args = [
    "-Wl,-rpath,$ORIGIN/../../torch/lib",
    "-L/usr/lib/x86_64-linux-gnu",
]
...@@ -116,11 +141,11 @@ ext_modules = [ ...@@ -116,11 +141,11 @@ ext_modules = [
# Package metadata: the version is sourced from pyproject.toml so setup.py
# and pyproject.toml cannot drift apart.
setup(
    name="sgl-kernel",
    version=_get_version(),
    packages=find_packages(),
    package_dir={"": "src"},
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)

# Rename the freshly built wheel in dist/ to carry the platform tag.
_update_wheel_platform_tag()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment