Unverified commit 160424a9, authored by Seungmin Kim, committed by GitHub

[Bugfix] Fix CUDA compatibility path setting for both datacenter and consumer NVIDIA GPUs (#33992)


Signed-off-by: Seungmin Kim <8457324+ehfd@users.noreply.github.com>
Signed-off-by: Andrew Mello <19512127+88plug@users.noreply.github.com>
Co-authored-by: 88plug <19512127+88plug@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent 9511a3f8
@@ -132,8 +132,10 @@ ENV UV_LINK_MODE=copy
 # Verify GCC version
 RUN gcc --version
-# Ensure CUDA compatibility library is loaded
-RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig
+# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1'
+# Only needed for datacenter/professional GPUs with older drivers.
+# See: https://docs.nvidia.com/deploy/cuda-compatibility/
+ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0
 # ============================================================
 # SLOW-CHANGING DEPENDENCIES BELOW
@@ -560,8 +562,10 @@ ENV UV_HTTP_TIMEOUT=500
 ENV UV_INDEX_STRATEGY="unsafe-best-match"
 ENV UV_LINK_MODE=copy
-# Ensure CUDA compatibility library is loaded
-RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig
+# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1'
+# Only needed for datacenter/professional GPUs with older drivers.
+# See: https://docs.nvidia.com/deploy/cuda-compatibility/
+ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0
 # ============================================================
 # SLOW-CHANGING DEPENDENCIES BELOW
......
@@ -297,6 +297,23 @@ You can add any other [engine-args](https://docs.vllm.ai/en/latest/configuration
 RUN uv pip install --system git+https://github.com/huggingface/transformers.git
 ```
+#### Running on Systems with Older CUDA Drivers
+vLLM's Docker image comes with [CUDA compatibility libraries](https://docs.nvidia.com/deploy/cuda-compatibility/index.html) pre-installed. This allows you to run vLLM on systems with NVIDIA drivers that are older than the CUDA Toolkit version used in the image, but only supports select professional and datacenter NVIDIA GPUs.
+To enable this feature, set the `VLLM_ENABLE_CUDA_COMPATIBILITY` environment variable to `1` or `true` when running the container:
+```bash
+docker run --runtime nvidia --gpus all \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    -p 8000:8000 \
+    --env "HF_TOKEN=<secret>" \
+    --env "VLLM_ENABLE_CUDA_COMPATIBILITY=1" \
+    vllm/vllm-openai <args...>
+```
+This will automatically configure `LD_LIBRARY_PATH` to point to the compatibility libraries before loading PyTorch and other dependencies.
 # --8<-- [end:pre-built-images]
 # --8<-- [start:build-image-from-source]
......
@@ -318,7 +318,32 @@ This indicates vLLM failed to initialize the NCCL communicator, possibly due to
 ## CUDA error: the provided PTX was compiled with an unsupported toolchain
-If you see an error like `RuntimeError: CUDA error: the provided PTX was compiled with an unsupported toolchain.`, it means that the CUDA PTX in vLLM's wheels was compiled with a toolchain unsupported by your system. The released vLLM wheels have to be compiled with a specific version of CUDA toolkit, and the compiled code might fail to run on lower versions of CUDA drivers. Read [cuda compatibility](https://docs.nvidia.com/deploy/cuda-compatibility/) for more details. The solution is to install `cuda-compat` package from your package manager. For example, on Ubuntu, you can run `sudo apt-get install cuda-compat-12-9`, and then add `export LD_LIBRARY_PATH=/usr/local/cuda-12.9/compat:$LD_LIBRARY_PATH` to your `.bashrc` file. When successfully installed, you should see that the output of `nvidia-smi` will show `CUDA Version: 12.9`. Note that we use CUDA 12.9 as an example here, you may want to install a higher version of cuda-compat package in case vLLM's default CUDA version goes higher.
+If you see an error like `RuntimeError: CUDA error: the provided PTX was compiled with an unsupported toolchain`, it means that the CUDA PTX in vLLM's wheels was compiled with a toolchain unsupported by your system. This section also applies if you get the error `RuntimeError: The NVIDIA driver on your system is too old`.
+The released vLLM wheels are compiled with a specific version of CUDA toolkit, and the compiled code might fail to run on lower versions of CUDA drivers. Read [CUDA compatibility](https://docs.nvidia.com/deploy/cuda-compatibility/) for more details. **This is only supported on select professional and datacenter NVIDIA GPUs.**
+If you are using the vLLM official Docker image, you can solve this by adding `-e VLLM_ENABLE_CUDA_COMPATIBILITY=1` to your `docker run` command. This will enable the pre-installed CUDA forward compatibility libraries.
+If you are running vLLM outside of Docker, the solution is to install the `cuda-compat` package from your package manager with the [CUDA repository](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/) enabled. For example, on Ubuntu, you can run `sudo apt-get install cuda-compat-12-9`, and then set `export VLLM_ENABLE_CUDA_COMPATIBILITY=1` and `export VLLM_CUDA_COMPATIBILITY_PATH="/usr/local/cuda-12.9/compat"`.
+On Conda, you can install the `conda-forge::cuda-compat` package (e.g., `conda install -c conda-forge cuda-compat=12.9`), then after activating the environment, set `export VLLM_ENABLE_CUDA_COMPATIBILITY=1` and `export VLLM_CUDA_COMPATIBILITY_PATH="${CONDA_PREFIX}/cuda-compat"`.
+You can verify the configuration works by running a minimal Python script that initializes CUDA via vLLM:
+```bash
+export VLLM_ENABLE_CUDA_COMPATIBILITY=1
+export VLLM_CUDA_COMPATIBILITY_PATH="/usr/local/cuda-12.9/compat"
+python3 - << 'EOF'
+import vllm
+import torch
+print(f"CUDA available: {torch.cuda.is_available()}")
+print(f"CUDA device count: {torch.cuda.device_count()}")
+EOF
+```
+Note that we use CUDA 12.9 as an example here, and you may want to install a higher version of cuda-compat package in case vLLM's default CUDA version goes higher.
 ## ptxas fatal: Value 'sm_110a' is not defined for option 'gpu-name'
......
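Reviewer note: for the non-Docker setup described in the troubleshooting text, a pre-flight check can confirm that the directory passed via `VLLM_CUDA_COMPATIBILITY_PATH` exists and holds a driver library before vLLM is launched. The following is a minimal sketch, not part of this commit. `find_compat_libcuda` is a hypothetical helper, and the `libcuda.so*` file naming is an assumption about how the `cuda-compat` package lays out its files.

```python
import glob
import os


def find_compat_libcuda(compat_dir: str) -> list[str]:
    """Return libcuda shared objects found directly inside compat_dir.

    Hypothetical helper: returns an empty list when the directory is
    missing or holds no file matching the assumed libcuda.so* pattern.
    """
    if not os.path.isdir(compat_dir):
        return []
    return sorted(glob.glob(os.path.join(compat_dir, "libcuda.so*")))


if __name__ == "__main__":
    # The fallback path mirrors the CUDA 12.9 example used in the docs text.
    path = os.environ.get(
        "VLLM_CUDA_COMPATIBILITY_PATH", "/usr/local/cuda-12.9/compat"
    )
    libs = find_compat_libcuda(path)
    print(f"{path}: {len(libs)} libcuda file(s) found")
```

An empty result suggests the path is wrong or the package is not installed, in which case enabling the flag would have no effect.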
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for CUDA forward compatibility path logic in env_override.py.
+
+Verifies the opt-in LD_LIBRARY_PATH manipulation for CUDA compat libs,
+including env var parsing, path detection, and deduplication.
+"""
+
+import os
+from unittest.mock import patch
+
+import pytest
+
+# Import the functions directly (they're module-level in env_override)
+# We must import them without triggering the module-level side effects,
+# so we import the functions by name after the module is already loaded.
+from vllm.env_override import (
+    _get_torch_cuda_version,
+    _maybe_set_cuda_compatibility_path,
+)
+
+
+class TestCudaCompatibilityEnvParsing:
+    """Test VLLM_ENABLE_CUDA_COMPATIBILITY env var parsing."""
+
+    def test_disabled_by_default(self, monkeypatch):
+        """Compat path is NOT set when env var is absent."""
+        monkeypatch.delenv("VLLM_ENABLE_CUDA_COMPATIBILITY", raising=False)
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        _maybe_set_cuda_compatibility_path()
+        assert (
+            "LD_LIBRARY_PATH" not in os.environ
+            or os.environ.get("LD_LIBRARY_PATH", "") == ""
+        )
+
+    @pytest.mark.parametrize("value", ["0", "false", "False", "no", ""])
+    def test_disabled_values(self, monkeypatch, value):
+        """Various falsy values should not activate compat path."""
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", value)
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        _maybe_set_cuda_compatibility_path()
+        # LD_LIBRARY_PATH should not be set (or remain empty)
+        ld_path = os.environ.get("LD_LIBRARY_PATH", "")
+        assert "compat" not in ld_path
+
+    @pytest.mark.parametrize("value", ["1", "true", "True", " 1 ", " TRUE "])
+    def test_enabled_values_with_valid_path(self, monkeypatch, tmp_path, value):
+        """Truthy values activate compat path when a valid path exists."""
+        compat_dir = tmp_path / "compat"
+        compat_dir.mkdir()
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", value)
+        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(compat_dir))
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        _maybe_set_cuda_compatibility_path()
+        ld_path = os.environ.get("LD_LIBRARY_PATH", "")
+        assert str(compat_dir) in ld_path
+
+
+class TestCudaCompatibilityPathDetection:
+    """Test path detection: custom override, conda, default."""
+
+    def test_custom_path_override(self, monkeypatch, tmp_path):
+        """VLLM_CUDA_COMPATIBILITY_PATH takes highest priority."""
+        custom_dir = tmp_path / "my-compat"
+        custom_dir.mkdir()
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(custom_dir))
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        _maybe_set_cuda_compatibility_path()
+        ld_path = os.environ.get("LD_LIBRARY_PATH", "")
+        assert ld_path.startswith(str(custom_dir))
+
+    def test_conda_prefix_fallback(self, monkeypatch, tmp_path):
+        """Falls back to $CONDA_PREFIX/cuda-compat if custom not set."""
+        conda_dir = tmp_path / "conda-env"
+        compat_dir = conda_dir / "cuda-compat"
+        compat_dir.mkdir(parents=True)
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.delenv("VLLM_CUDA_COMPATIBILITY_PATH", raising=False)
+        monkeypatch.setenv("CONDA_PREFIX", str(conda_dir))
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        _maybe_set_cuda_compatibility_path()
+        ld_path = os.environ.get("LD_LIBRARY_PATH", "")
+        assert str(compat_dir) in ld_path
+
+    def test_no_valid_path_does_nothing(self, monkeypatch):
+        """When enabled but no valid path exists, LD_LIBRARY_PATH unchanged."""
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", "/nonexistent/path")
+        monkeypatch.delenv("CONDA_PREFIX", raising=False)
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        with patch("vllm.env_override._get_torch_cuda_version", return_value=None):
+            _maybe_set_cuda_compatibility_path()
+        assert os.environ.get("LD_LIBRARY_PATH", "") == ""
+
+    def test_default_cuda_path_fallback(self, monkeypatch, tmp_path):
+        """Falls back to /usr/local/cuda-{ver}/compat via torch version."""
+        fake_cuda = tmp_path / "cuda-12.8" / "compat"
+        fake_cuda.mkdir(parents=True)
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.delenv("VLLM_CUDA_COMPATIBILITY_PATH", raising=False)
+        monkeypatch.delenv("CONDA_PREFIX", raising=False)
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        with (
+            patch("vllm.env_override._get_torch_cuda_version", return_value="12.8"),
+            patch(
+                "vllm.env_override.os.path.isdir",
+                side_effect=lambda p: p == "/usr/local/cuda-12.8/compat"
+                or os.path.isdir(p),
+            ),
+        ):
+            _maybe_set_cuda_compatibility_path()
+        ld_path = os.environ.get("LD_LIBRARY_PATH", "")
+        assert "/usr/local/cuda-12.8/compat" in ld_path
+
+
+class TestCudaCompatibilityLdPathManipulation:
+    """Test LD_LIBRARY_PATH prepend and deduplication logic."""
+
+    def test_prepends_to_empty_ld_path(self, monkeypatch, tmp_path):
+        """Compat path is set when LD_LIBRARY_PATH is empty."""
+        compat_dir = tmp_path / "compat"
+        compat_dir.mkdir()
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(compat_dir))
+        monkeypatch.delenv("LD_LIBRARY_PATH", raising=False)
+        _maybe_set_cuda_compatibility_path()
+        assert os.environ["LD_LIBRARY_PATH"] == str(compat_dir)
+
+    def test_prepends_to_existing_ld_path(self, monkeypatch, tmp_path):
+        """Compat path is prepended before existing entries."""
+        compat_dir = tmp_path / "compat"
+        compat_dir.mkdir()
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(compat_dir))
+        monkeypatch.setenv("LD_LIBRARY_PATH", "/usr/lib:/other/lib")
+        _maybe_set_cuda_compatibility_path()
+        ld_path = os.environ["LD_LIBRARY_PATH"]
+        parts = ld_path.split(os.pathsep)
+        assert parts[0] == str(compat_dir)
+        assert "/usr/lib" in parts
+        assert "/other/lib" in parts
+
+    def test_deduplicates_existing_compat_path(self, monkeypatch, tmp_path):
+        """If compat path already in LD_LIBRARY_PATH, move to front."""
+        compat_dir = tmp_path / "compat"
+        compat_dir.mkdir()
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(compat_dir))
+        monkeypatch.setenv(
+            "LD_LIBRARY_PATH",
+            f"/usr/lib:{compat_dir}:/other/lib",
+        )
+        _maybe_set_cuda_compatibility_path()
+        ld_path = os.environ["LD_LIBRARY_PATH"]
+        parts = ld_path.split(os.pathsep)
+        assert parts[0] == str(compat_dir)
+        assert parts.count(str(compat_dir)) == 1
+
+    def test_already_at_front_is_noop(self, monkeypatch, tmp_path):
+        """If compat path is already first, don't modify LD_LIBRARY_PATH."""
+        compat_dir = tmp_path / "compat"
+        compat_dir.mkdir()
+        original = f"{compat_dir}:/usr/lib"
+        monkeypatch.setenv("VLLM_ENABLE_CUDA_COMPATIBILITY", "1")
+        monkeypatch.setenv("VLLM_CUDA_COMPATIBILITY_PATH", str(compat_dir))
+        monkeypatch.setenv("LD_LIBRARY_PATH", original)
+        _maybe_set_cuda_compatibility_path()
+        assert os.environ["LD_LIBRARY_PATH"] == original
+
+
+class TestGetTorchCudaVersion:
+    """Test _get_torch_cuda_version() helper."""
+
+    def test_returns_string_when_torch_available(self):
+        """Should return a CUDA version string like '12.8'."""
+        version = _get_torch_cuda_version()
+        # torch is installed in vllm's environment
+        assert version is None or isinstance(version, str)
+
+    def test_returns_none_when_torch_missing(self):
+        """Should return None when torch is not importable."""
+        with patch(
+            "vllm.env_override.importlib.util.find_spec",
+            return_value=None,
+        ):
+            assert _get_torch_cuda_version() is None
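
Reviewer note: `_get_torch_cuda_version()` relies on executing a single `version.py` file rather than importing its package. That technique can be shown in isolation with a temporary file. This is an illustrative sketch only, not part of the commit; `load_cuda_attr` and the module name `fake_torch_version` are hypothetical.

```python
import importlib.util
import os
import tempfile


def load_cuda_attr(version_path: str):
    """Execute a standalone version.py and return its `cuda` attribute.

    Only the given file runs; no parent package is imported and nothing
    is registered in sys.modules.
    """
    spec = importlib.util.spec_from_file_location("fake_torch_version", version_path)
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, "cuda", None)


with tempfile.TemporaryDirectory() as tmp:
    version_path = os.path.join(tmp, "version.py")
    with open(version_path, "w") as f:
        f.write("cuda = '12.8'\n")
    print(load_cuda_attr(version_path))  # prints: 12.8
```

A fixture along these lines could make the version test assert an exact string instead of accepting either `None` or any `str`.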
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E402
+import importlib.util
 import os
+
+
+def _get_torch_cuda_version():
+    """Peripheral function to _maybe_set_cuda_compatibility_path().
+    PyTorch version must not be determined by importing directly
+    because it will trigger the CUDA initialization, losing the
+    chance to set the LD_LIBRARY_PATH beforehand.
+    """
+    try:
+        spec = importlib.util.find_spec("torch")
+        if not spec:
+            return None
+        if spec.origin:
+            torch_root = os.path.dirname(spec.origin)
+        elif spec.submodule_search_locations:
+            torch_root = spec.submodule_search_locations[0]
+        else:
+            return None
+        version_path = os.path.join(torch_root, "version.py")
+        if not os.path.exists(version_path):
+            return None
+        # Load the version module without importing torch
+        ver_spec = importlib.util.spec_from_file_location("torch.version", version_path)
+        if not ver_spec or not ver_spec.loader:
+            return None
+        module = importlib.util.module_from_spec(ver_spec)
+        # Avoid registering in sys.modules to not confuse future imports
+        ver_spec.loader.exec_module(module)
+        return getattr(module, "cuda", None)
+    except Exception:
+        return None
+
+
+def _maybe_set_cuda_compatibility_path():
+    """Set LD_LIBRARY_PATH for CUDA forward compatibility if enabled.
+    Must run before 'import torch' since torch loads CUDA shared libraries
+    at import time and the dynamic linker only consults LD_LIBRARY_PATH when
+    a library is first loaded.
+    CUDA forward compatibility is only supported on select professional and
+    datacenter NVIDIA GPUs. Consumer GPUs (GeForce, RTX) do not support it
+    and will get Error 803 if compat libs are loaded.
+    """
+    enable = os.environ.get("VLLM_ENABLE_CUDA_COMPATIBILITY", "0").strip().lower() in (
+        "1",
+        "true",
+    )
+    if not enable:
+        return
+    cuda_compat_path = os.environ.get("VLLM_CUDA_COMPATIBILITY_PATH", "")
+    if not cuda_compat_path or not os.path.isdir(cuda_compat_path):
+        conda_prefix = os.environ.get("CONDA_PREFIX", "")
+        conda_compat = os.path.join(conda_prefix, "cuda-compat")
+        if conda_prefix and os.path.isdir(conda_compat):
+            cuda_compat_path = conda_compat
+    if not cuda_compat_path or not os.path.isdir(cuda_compat_path):
+        torch_cuda_version = _get_torch_cuda_version()
+        if torch_cuda_version:
+            default_path = f"/usr/local/cuda-{torch_cuda_version}/compat"
+            if os.path.isdir(default_path):
+                cuda_compat_path = default_path
+    if not cuda_compat_path or not os.path.isdir(cuda_compat_path):
+        return
+    norm_path = os.path.normpath(cuda_compat_path)
+    existing = os.environ.get("LD_LIBRARY_PATH", "")
+    ld_paths = existing.split(os.pathsep) if existing else []
+    if ld_paths and ld_paths[0] and os.path.normpath(ld_paths[0]) == norm_path:
+        return  # Already at the front
+    new_paths = [norm_path] + [
+        p for p in ld_paths if not p or os.path.normpath(p) != norm_path
+    ]
+    os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(new_paths)
+
+
+_maybe_set_cuda_compatibility_path()
 import torch
 from vllm.logger import init_logger
......
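Reviewer note: the prepend-and-deduplicate step at the end of `_maybe_set_cuda_compatibility_path()` can be read as a pure function over the search-path string. The sketch below restates that logic outside vLLM for illustration. `prepend_unique` is a hypothetical name, and the separator is passed explicitly so the examples do not depend on `os.pathsep`; POSIX-style paths are assumed.

```python
import os


def prepend_unique(existing: str, new_dir: str, sep: str = os.pathsep) -> str:
    """Put new_dir first in a search-path string, dropping later duplicates.

    Entries are compared after os.path.normpath, so a trailing slash does
    not defeat deduplication. Empty entries are kept as they are.
    """
    norm = os.path.normpath(new_dir)
    parts = existing.split(sep) if existing else []
    if parts and parts[0] and os.path.normpath(parts[0]) == norm:
        return existing  # already at the front, leave untouched
    kept = [p for p in parts if not p or os.path.normpath(p) != norm]
    return sep.join([norm] + kept)


# Empty input yields just the new directory.
assert prepend_unique("", "/opt/compat", ":") == "/opt/compat"
# A duplicate later in the list (even with a trailing slash) is moved to the front.
assert (
    prepend_unique("/usr/lib:/opt/compat/:/other", "/opt/compat", ":")
    == "/opt/compat:/usr/lib:/other"
)
# Already first: the original string is returned unchanged.
assert prepend_unique("/opt/compat:/usr/lib", "/opt/compat/", ":") == "/opt/compat:/usr/lib"
```

Returning the original string in the already-first case matches the early `return` in the committed function, which avoids rewriting `LD_LIBRARY_PATH` when nothing needs to change.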
@@ -239,6 +239,8 @@ if TYPE_CHECKING:
     VLLM_WEIGHT_OFFLOADING_DISABLE_UVA: bool = False
     VLLM_DISABLE_LOG_LOGO: bool = False
     VLLM_LORA_DISABLE_PDL: bool = False
+    VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False
+    VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
 def get_default_cache_root():
@@ -1591,6 +1593,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes
     # Triton compilation to fail.
     "VLLM_LORA_DISABLE_PDL": lambda: bool(int(os.getenv("VLLM_LORA_DISABLE_PDL", "0"))),
+    # Enable CUDA compatibility mode for datacenter GPUs with older
+    # driver versions than the CUDA toolkit major version of vLLM.
+    "VLLM_ENABLE_CUDA_COMPATIBILITY": lambda: (
+        os.environ.get("VLLM_ENABLE_CUDA_COMPATIBILITY", "0").strip().lower()
+        in ("1", "true")
+    ),
+    # Path to the CUDA compatibility libraries when CUDA compatibility is enabled.
+    "VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get(
+        "VLLM_CUDA_COMPATIBILITY_PATH", None
+    ),
 }
@@ -1731,6 +1743,8 @@ def compile_factors() -> dict[str, object]:
         "VLLM_CPU_MOE_PREPACK",
         "VLLM_CPU_SGL_KERNEL",
         "VLLM_TEST_FORCE_LOAD_FORMAT",
+        "VLLM_ENABLE_CUDA_COMPATIBILITY",
+        "VLLM_CUDA_COMPATIBILITY_PATH",
         "LOCAL_RANK",
         "CUDA_VISIBLE_DEVICES",
         "NO_COLOR",
......
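Reviewer note: the flag is parsed by membership in `("1", "true")` after stripping and lowercasing, so only those two spellings enable it. Values such as `yes` or `on` are treated as disabled rather than raising, unlike the neighbouring `bool(int(...))` pattern. A small sketch of that rule follows; `flag_enabled` is a hypothetical name, not a vLLM function.

```python
from typing import Optional


def flag_enabled(raw: Optional[str]) -> bool:
    """Mirror the parsing rule: only '1' or 'true' (any case, padded) enable."""
    return (raw if raw is not None else "0").strip().lower() in ("1", "true")


for value in (None, "", "0", "no", "yes", "on", "1", " TRUE "):
    print(repr(value), flag_enabled(value))
```

Users who set `VLLM_ENABLE_CUDA_COMPATIBILITY=yes` would therefore see no effect, which may be worth stating in the docs.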