Unverified Commit 6fc37bd8 authored by Ke Bao, committed by GitHub

Fix sgl-kernel compile for sm80 (#3046)

parent 3d8f1c9b
@@ -24,6 +24,22 @@ def update_wheel_platform_tag():
     old_wheel.rename(new_wheel)
 
 
+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
+def get_device_sm():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major * 10 + minor
+    return 0
+
+
+cuda_version = get_cuda_version()
+sm_version = get_device_sm()
+
 cutlass = root / "3rdparty" / "cutlass"
 flashinfer = root / "3rdparty" / "flashinfer"
 include_dirs = [
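The two new helpers read the CUDA version that PyTorch was built with and the compute capability of the visible GPU, falling back to (0, 0) and 0 when CUDA is unavailable. A quick standalone check of what they would report on a given build machine might look like the sketch below (the values in the trailing comment are illustrative, not taken from this commit):

import torch

# Mirrors the logic of get_cuda_version() / get_device_sm() from the hunk above.
cuda_version = tuple(map(int, torch.version.cuda.split("."))) if torch.version.cuda else (0, 0)
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    sm_version = major * 10 + minor
else:
    sm_version = 0
print(cuda_version, sm_version)  # e.g. (12, 1) 80 on an A100, (11, 8) 80 with CUDA 11.8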
@@ -42,12 +58,15 @@ nvcc_flags = [
     "-gencode=arch=compute_80,code=sm_80",
     "-gencode=arch=compute_89,code=sm_89",
     "-gencode=arch=compute_90,code=sm_90",
-    "-gencode=arch=compute_90a,code=sm_90a",
     "-std=c++17",
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
     "-DFLASHINFER_ENABLE_BF16",
 ]
+
+if cuda_version >= (12, 0) and sm_version >= 90:
+    nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+
 for flag in [
     "-D__CUDA_NO_HALF_OPERATORS__",
     "-D__CUDA_NO_HALF_CONVERSIONS__",
...
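The second hunk stops requesting the Hopper-specific compute_90a target unconditionally. The compute_90a/sm_90a architecture is only accepted by CUDA 12+ toolchains, so passing it on an older build host (for example an sm80 A100 machine with CUDA 11.x) made nvcc abort, which is the sm80 compile failure this commit fixes. A minimal sketch of the resulting gating, assuming the helpers from the first hunk and omitting the flags that are unchanged:

nvcc_flags = [
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
    # ... remaining flags unchanged ...
]

# compute_90a is only understood by CUDA 12+ nvcc and only benefits SM 90
# (Hopper) GPUs, so the flag is appended only when both conditions hold.
if cuda_version >= (12, 0) and sm_version >= 90:
    nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")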