Unverified Commit f3aca1ee authored by Yang Chen's avatar Yang Chen Committed by GitHub
Browse files

setup correct nvcc version with CUDA_HOME (#15725)


Signed-off-by: default avatarYang Chen <yangche@fb.com>
parent 8dd41d6b
...@@ -201,6 +201,9 @@ class cmake_build_ext(build_ext): ...@@ -201,6 +201,9 @@ class cmake_build_ext(build_ext):
else: else:
# Default build tool to whatever cmake picks. # Default build tool to whatever cmake picks.
build_tool = [] build_tool = []
# Make sure we use the nvcc from CUDA_HOME
if _is_cuda():
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
subprocess.check_call( subprocess.check_call(
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args], ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
cwd=self.build_temp) cwd=self.build_temp)
...@@ -639,11 +642,10 @@ if _is_hip(): ...@@ -639,11 +642,10 @@ if _is_hip():
if _is_cuda(): if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"): if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
# FA3 requires CUDA 12.0 or later # FA3 requires CUDA 12.3 or later
ext_modules.append( ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
# Optional since this doesn't get built (produce an .so file) when # Optional since this doesn't get built (produce an .so file) when
# not targeting a hopper system # not targeting a hopper system
ext_modules.append( ext_modules.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment