Unverified commit 4dfad17e, authored by wliao2 and committed by GitHub
Browse files

Replace cuda_device_count_stateless() with current_platform.device_count() (#37841)


Signed-off-by: Liao, Wei <wei.liao@intel.com>
Signed-off-by: wliao2 <wei.liao@intel.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
parent e8057c00
...@@ -6,7 +6,6 @@ import pytest ...@@ -6,7 +6,6 @@ import pytest
from vllm.config import CompilationMode from vllm.config import CompilationMode
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
from ...utils import compare_all_settings from ...utils import compare_all_settings
...@@ -109,10 +108,10 @@ def test_compile_correctness( ...@@ -109,10 +108,10 @@ def test_compile_correctness(
tp_size = test_setting.tp_size tp_size = test_setting.tp_size
attn_backend = test_setting.attn_backend attn_backend = test_setting.attn_backend
method = test_setting.method method = test_setting.method
if cuda_device_count_stateless() < pp_size * tp_size: if current_platform.device_count() < pp_size * tp_size:
pytest.skip( pytest.skip(
f"Need at least {pp_size}*{tp_size} CUDA gpus but got " f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}" f"{current_platform.device_count()}"
) )
final_args = [ final_args = [
......
...@@ -412,7 +412,7 @@ def test_cudagraph_sizes_post_init( ...@@ -412,7 +412,7 @@ def test_cudagraph_sizes_post_init(
with ( with (
ctx, ctx,
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size), patch.object(current_platform, "device_count", return_value=tp_size),
): ):
kwargs = {} kwargs = {}
if cudagraph_capture_sizes is not None: if cudagraph_capture_sizes is not None:
......
...@@ -13,7 +13,6 @@ from vllm.distributed.utils import StatelessProcessGroup ...@@ -13,7 +13,6 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
...@@ -21,7 +20,7 @@ from ..utils import multi_gpu_test ...@@ -21,7 +20,7 @@ from ..utils import multi_gpu_test
@ray.remote @ray.remote
class _CUDADeviceCountStatelessTestActor: class _CUDADeviceCountStatelessTestActor:
def get_count(self): def get_count(self):
return cuda_device_count_stateless() return current_platform.device_count()
def set_cuda_visible_devices(self, cuda_visible_devices: str): def set_cuda_visible_devices(self, cuda_visible_devices: str):
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
......
...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.utils.torch_utils import cuda_device_count_stateless, set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from .modular_kernel_tools.common import ( from .modular_kernel_tools.common import (
...@@ -310,10 +310,10 @@ def test_modular_kernel_combinations_multigpu( ...@@ -310,10 +310,10 @@ def test_modular_kernel_combinations_multigpu(
world_size: int, world_size: int,
pytestconfig, pytestconfig,
): ):
if cuda_device_count_stateless() < world_size: if current_platform.device_count() < world_size:
pytest.skip( pytest.skip(
f"Not enough GPUs available to run, got " f"Not enough GPUs available to run, got "
f"{cuda_device_count_stateless()} expected " f"{current_platform.device_count()} expected "
f"{world_size}." f"{world_size}."
) )
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.reload.meta import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.reload.meta import (
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
def test_move_metatensors(): def test_move_metatensors():
...@@ -140,7 +139,7 @@ def test_get_numel_loaded(): ...@@ -140,7 +139,7 @@ def test_get_numel_loaded():
], ],
) )
def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner): def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
if cuda_device_count_stateless() < tp_size: if current_platform.device_count() < tp_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough CUDA devices")
if "FP8" in base_model and not current_platform.supports_fp8(): if "FP8" in base_model and not current_platform.supports_fp8():
...@@ -206,8 +205,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner): ...@@ -206,8 +205,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
def test_online_quantize_reload( def test_online_quantize_reload(
base_model, mul_model, add_model, quantization, tp_size, vllm_runner base_model, mul_model, add_model, quantization, tp_size, vllm_runner
): ):
if cuda_device_count_stateless() < tp_size: if current_platform.device_count() < tp_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough GPU devices")
if quantization == "fp8" and not current_platform.supports_fp8(): if quantization == "fp8" and not current_platform.supports_fp8():
pytest.skip(reason="Requires FP8 support") pytest.skip(reason="Requires FP8 support")
......
...@@ -21,8 +21,8 @@ import lm_eval ...@@ -21,8 +21,8 @@ import lm_eval
import pytest import pytest
from packaging import version from packaging import version
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx950 from vllm.platforms.rocm import on_gfx950
from vllm.utils.torch_utils import cuda_device_count_stateless
MODEL_ACCURACIES = { MODEL_ACCURACIES = {
# Full quantization: attention linears and MoE linears # Full quantization: attention linears and MoE linears
...@@ -89,7 +89,7 @@ def test_gpt_oss_attention_quantization( ...@@ -89,7 +89,7 @@ def test_gpt_oss_attention_quantization(
expected_accuracy: float, expected_accuracy: float,
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
): ):
if tp_size > cuda_device_count_stateless(): if tp_size > current_platform.device_count():
pytest.skip("Not enough GPUs to run this test case") pytest.skip("Not enough GPUs to run this test case")
if "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8" in model_name and on_gfx950(): if "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8" in model_name and on_gfx950():
......
...@@ -58,7 +58,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser ...@@ -58,7 +58,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GB_bytes from vllm.utils.mem_constants import GB_bytes
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
cuda_device_count_stateless,
set_random_seed, # noqa: F401 - re-exported for use in test files set_random_seed, # noqa: F401 - re-exported for use in test files
) )
...@@ -384,7 +383,7 @@ class RemoteVLLMServer: ...@@ -384,7 +383,7 @@ class RemoteVLLMServer:
elif current_platform.is_cuda(): elif current_platform.is_cuda():
with _nvml(): with _nvml():
total_used = 0 total_used = 0
device_count = cuda_device_count_stateless() device_count = current_platform.device_count()
for i in range(device_count): for i in range(device_count):
handle = nvmlDeviceGetHandleByIndex(i) handle = nvmlDeviceGetHandleByIndex(i)
mem_info = nvmlDeviceGetMemoryInfo(handle) mem_info = nvmlDeviceGetMemoryInfo(handle)
...@@ -1497,7 +1496,7 @@ def multi_gpu_marks(*, num_gpus: int): ...@@ -1497,7 +1496,7 @@ def multi_gpu_marks(*, num_gpus: int):
"""Get a collection of pytest marks to apply for `@multi_gpu_test`.""" """Get a collection of pytest marks to apply for `@multi_gpu_test`."""
test_selector = pytest.mark.distributed(num_gpus=num_gpus) test_selector = pytest.mark.distributed(num_gpus=num_gpus)
test_skipif = pytest.mark.skipif( test_skipif = pytest.mark.skipif(
cuda_device_count_stateless() < num_gpus, current_platform.device_count() < num_gpus,
reason=f"Need at least {num_gpus} GPUs to run the test.", reason=f"Need at least {num_gpus} GPUs to run the test.",
) )
...@@ -1529,7 +1528,7 @@ def gpu_tier_mark(*, min_gpus: int = 1, max_gpus: int | None = None): ...@@ -1529,7 +1528,7 @@ def gpu_tier_mark(*, min_gpus: int = 1, max_gpus: int | None = None):
@gpu_tier_mark(max_gpus=1) # only on single-GPU @gpu_tier_mark(max_gpus=1) # only on single-GPU
@gpu_tier_mark(min_gpus=2, max_gpus=4) # 2-4 GPUs only @gpu_tier_mark(min_gpus=2, max_gpus=4) # 2-4 GPUs only
""" """
gpu_count = cuda_device_count_stateless() gpu_count = current_platform.device_count()
marks = [] marks = []
if min_gpus > 1: if min_gpus > 1:
......
...@@ -11,8 +11,8 @@ from tests.v1.shutdown.utils import ( ...@@ -11,8 +11,8 @@ from tests.v1.shutdown.utils import (
) )
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
...@@ -34,7 +34,7 @@ async def test_async_llm_delete( ...@@ -34,7 +34,7 @@ async def test_async_llm_delete(
tensor_parallel_size: degree of tensor parallelism tensor_parallel_size: degree of tensor parallelism
send_one_request: send one request to engine before deleting send_one_request: send one request to engine before deleting
""" """
if cuda_device_count_stateless() < tensor_parallel_size: if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough CUDA devices")
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
...@@ -83,7 +83,7 @@ def test_llm_delete( ...@@ -83,7 +83,7 @@ def test_llm_delete(
enable_multiprocessing: enable workers in separate process(es) enable_multiprocessing: enable workers in separate process(es)
send_one_request: send one request to engine before deleting send_one_request: send one request to engine before deleting
""" """
if cuda_device_count_stateless() < tensor_parallel_size: if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m: with monkeypatch.context() as m:
......
...@@ -15,7 +15,7 @@ from tests.v1.shutdown.utils import ( ...@@ -15,7 +15,7 @@ from tests.v1.shutdown.utils import (
from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.platforms import current_platform
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
...@@ -60,7 +60,7 @@ async def test_async_llm_model_error( ...@@ -60,7 +60,7 @@ async def test_async_llm_model_error(
AsyncLLM always uses an MP client. AsyncLLM always uses an MP client.
""" """
if cuda_device_count_stateless() < tensor_parallel_size: if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model. # Monkeypatch an error in the model.
...@@ -126,7 +126,7 @@ def test_llm_model_error( ...@@ -126,7 +126,7 @@ def test_llm_model_error(
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
and >1 rank and >1 rank
""" """
if cuda_device_count_stateless() < tensor_parallel_size: if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m: with monkeypatch.context() as m:
......
...@@ -15,7 +15,7 @@ from vllm import LLM ...@@ -15,7 +15,7 @@ from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.platforms import current_platform
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
...@@ -57,7 +57,7 @@ def test_async_llm_startup_error( ...@@ -57,7 +57,7 @@ def test_async_llm_startup_error(
Test profiling (forward()) and load weights failures. Test profiling (forward()) and load weights failures.
AsyncLLM always uses an MP client. AsyncLLM always uses an MP client.
""" """
if cuda_device_count_stateless() < tensor_parallel_size: if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model. # Monkeypatch an error in the model.
...@@ -99,7 +99,7 @@ def test_llm_startup_error( ...@@ -99,7 +99,7 @@ def test_llm_startup_error(
# If MODELS list grows, each architecture needs its own test variant. # If MODELS list grows, each architecture needs its own test variant.
if model != "JackFram/llama-68m": if model != "JackFram/llama-68m":
pytest.skip(reason="Only test JackFram/llama-68m") pytest.skip(reason="Only test JackFram/llama-68m")
if cuda_device_count_stateless() < tensor_parallel_size: if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices") pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m: with monkeypatch.context() as m:
......
...@@ -10,6 +10,8 @@ import regex as re ...@@ -10,6 +10,8 @@ import regex as re
_TORCH_CUDA_PATTERNS = [ _TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device|device\()\b", r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device|device\()\b",
r"\bwith\storch\.cuda\.device\b", r"\bwith\storch\.cuda\.device\b",
# Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml} internally
r"\bcuda_device_count_stateless\(\)\b",
] ]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"} ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
......
...@@ -16,7 +16,6 @@ from vllm.config.utils import config ...@@ -16,7 +16,6 @@ from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_ports_list from vllm.utils.network_utils import get_open_ports_list
from vllm.utils.torch_utils import cuda_device_count_stateless
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv from ray.runtime_env import RuntimeEnv
...@@ -726,9 +725,9 @@ class ParallelConfig: ...@@ -726,9 +725,9 @@ class ParallelConfig:
backend = "mp" backend = "mp"
elif ( elif (
current_platform.is_cuda() current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size and current_platform.device_count() < self.world_size
): ):
gpu_count = cuda_device_count_stateless() gpu_count = current_platform.device_count()
raise ValueError( raise ValueError(
f"World size ({self.world_size}) is larger than the number of " f"World size ({self.world_size}) is larger than the number of "
f"available GPUs ({gpu_count}) in this node. If this is " f"available GPUs ({gpu_count}) in this node. If this is "
......
...@@ -19,8 +19,8 @@ import torch.multiprocessing as mp ...@@ -19,8 +19,8 @@ import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -320,7 +320,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: ...@@ -320,7 +320,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
is_distributed = dist.is_initialized() is_distributed = dist.is_initialized()
num_dev = cuda_device_count_stateless() num_dev = current_platform.device_count()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None: if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
......
...@@ -17,7 +17,6 @@ from vllm.distributed.device_communicators.all_reduce_utils import ( ...@@ -17,7 +17,6 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
try: try:
ops.meta_size() ops.meta_size()
...@@ -135,7 +134,7 @@ class CustomAllreduce: ...@@ -135,7 +134,7 @@ class CustomAllreduce:
if cuda_visible_devices: if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(","))) device_ids = list(map(int, cuda_visible_devices.split(",")))
else: else:
device_ids = list(range(cuda_device_count_stateless())) device_ids = list(range(current_platform.device_count()))
physical_device_id = device_ids[device.index] physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
......
...@@ -13,7 +13,6 @@ from vllm.config import get_current_vllm_config_or_none ...@@ -13,7 +13,6 @@ from vllm.config import get_current_vllm_config_or_none
from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -137,7 +136,7 @@ class QuickAllReduce: ...@@ -137,7 +136,7 @@ class QuickAllReduce:
if cuda_visible_devices: if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(","))) device_ids = list(map(int, cuda_visible_devices.split(",")))
else: else:
device_ids = list(range(cuda_device_count_stateless())) device_ids = list(range(current_platform.device_count()))
physical_device_id = device_ids[device.index] physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [ gather_list = [
......
...@@ -9,7 +9,7 @@ from __future__ import annotations ...@@ -9,7 +9,7 @@ from __future__ import annotations
import os import os
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from functools import cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, TypeVar
import torch import torch
...@@ -20,9 +20,9 @@ from typing_extensions import ParamSpec ...@@ -20,9 +20,9 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
import vllm._C # noqa import vllm._C # noqa
import vllm._C_stable_libtorch # noqa import vllm._C_stable_libtorch # noqa
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
...@@ -47,6 +47,32 @@ pynvml = import_pynvml() ...@@ -47,6 +47,32 @@ pynvml = import_pynvml()
torch.backends.cuda.enable_cudnn_sdp(False) torch.backends.cuda.enable_cudnn_sdp(False)
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
"""Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES
at the time of call.
This should be used instead of torch.accelerator.device_count() unless
CUDA_VISIBLE_DEVICES has already been set to the desired value.
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released."""
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
if not torch.cuda._is_compiled():
return 0
raw_count = torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
@cache @cache
def _get_backend_priorities( def _get_backend_priorities(
use_mla: bool, use_mla: bool,
...@@ -456,7 +482,7 @@ class CudaPlatformBase(Platform): ...@@ -456,7 +482,7 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
@classmethod @classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
......
...@@ -13,7 +13,6 @@ from torch.distributed.distributed_c10d import is_nccl_available ...@@ -13,7 +13,6 @@ from torch.distributed.distributed_c10d import is_nccl_available
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
...@@ -67,6 +66,38 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = { ...@@ -67,6 +66,38 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
} }
@lru_cache(maxsize=8)
def _rocm_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
"""Get number of ROCm devices, caching based on the value of CUDA_VISIBLE_DEVICES
at the time of call.
This should be used instead of torch.accelerator.device_count() unless
CUDA_VISIBLE_DEVICES has already been set to the desired value.
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released."""
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
if not torch.cuda._is_compiled():
return 0
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = (
torch.cuda._device_count_amdsmi()
if (hasattr(torch.cuda, "_device_count_amdsmi"))
else -1
)
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
def _sync_hip_cuda_env_vars(): def _sync_hip_cuda_env_vars():
"""Ensure HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are consistent. """Ensure HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are consistent.
Treats empty string as unset. Raises on genuine conflicts.""" Treats empty string as unset. Raises on genuine conflicts."""
...@@ -810,7 +841,7 @@ class RocmPlatform(Platform): ...@@ -810,7 +841,7 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return cuda_device_count_stateless() return _rocm_device_count_stateless(getattr(envs, cls.device_control_env_var))
@classmethod @classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype): def check_if_supports_dtype(cls, dtype: torch.dtype):
......
...@@ -22,7 +22,6 @@ import vllm.envs as envs ...@@ -22,7 +22,6 @@ import vllm.envs as envs
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.platform_utils import cuda_get_device_properties from vllm.utils.platform_utils import cuda_get_device_properties
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -196,7 +195,7 @@ class UsageMessage: ...@@ -196,7 +195,7 @@ class UsageMessage:
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
self.gpu_count = cuda_device_count_stateless() self.gpu_count = current_platform.device_count()
self.gpu_type, self.gpu_memory_per_device = cuda_get_device_properties( self.gpu_type, self.gpu_memory_per_device = cuda_get_device_properties(
0, ("name", "total_memory") 0, ("name", "total_memory")
) )
......
...@@ -6,7 +6,6 @@ import os ...@@ -6,7 +6,6 @@ import os
import random import random
import threading import threading
from collections.abc import Callable, Collection from collections.abc import Callable, Collection
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, TypeVar
import numpy as np import numpy as np
...@@ -16,7 +15,6 @@ from packaging import version ...@@ -16,7 +15,6 @@ from packaging import version
from packaging.version import Version from packaging.version import Version
from torch.library import Library, infer_schema from torch.library import Library, infer_schema
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -590,49 +588,6 @@ def aux_stream() -> torch.cuda.Stream | None: ...@@ -590,49 +588,6 @@ def aux_stream() -> torch.cuda.Stream | None:
return _aux_stream return _aux_stream
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
import torch.version
from vllm.platforms import current_platform
if not torch.cuda._is_compiled():
return 0
if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = (
torch.cuda._device_count_amdsmi()
if (hasattr(torch.cuda, "_device_count_amdsmi"))
else -1
)
else:
raw_count = torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
def cuda_device_count_stateless() -> int:
    """Return the CUDA device count for the current CUDA_VISIBLE_DEVICES.

    The count is cached on the value of ``CUDA_VISIBLE_DEVICES`` read at
    call time. Prefer this over ``torch.accelerator.device_count()`` unless
    ``CUDA_VISIBLE_DEVICES`` has already been set to the desired value.
    """
    # This wrapper can be dropped in favour of torch's own device count once
    # https://github.com/pytorch/pytorch/pull/122815 is released.
    visible_devices = envs.CUDA_VISIBLE_DEVICES
    return _cuda_device_count_stateless(visible_devices)
def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensor(tensor: Any) -> Any:
""" """
Create a weak reference to a tensor. Create a weak reference to a tensor.
......
...@@ -369,9 +369,7 @@ def initialize_ray_cluster( ...@@ -369,9 +369,7 @@ def initialize_ray_cluster(
# Prevalidate GPU requirements before Ray processing # Prevalidate GPU requirements before Ray processing
if current_platform.is_cuda() and parallel_config.world_size > 1: if current_platform.is_cuda() and parallel_config.world_size > 1:
from vllm.utils.torch_utils import cuda_device_count_stateless available_gpus = current_platform.device_count()
available_gpus = cuda_device_count_stateless()
if parallel_config.world_size > available_gpus: if parallel_config.world_size > available_gpus:
logger.warning( logger.warning(
"Tensor parallel size (%d) exceeds available GPUs (%d). " "Tensor parallel size (%d) exceeds available GPUs (%d). "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment