Unverified Commit 74374386 authored by Sairam Pillai's avatar Sairam Pillai Committed by GitHub
Browse files

[Bugfix] Improve GPU validation logging in Ray fallback scenarios (#25775)


Signed-off-by: default avatarSairam Pillai <sairam.pillai61@gmail.com>
parent c01f6e52
...@@ -521,15 +521,11 @@ class ParallelConfig: ...@@ -521,15 +521,11 @@ class ParallelConfig:
current_platform.is_cuda() current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size and cuda_device_count_stateless() < self.world_size
): ):
if not ray_found: gpu_count = cuda_device_count_stateless()
raise ValueError( raise ValueError(
"Unable to load Ray: " f"Tensor parallel size ({self.world_size}) cannot be "
f"{ray_utils.ray_import_err}. Ray is " f"larger than the number of available GPUs ({gpu_count})."
"required for multi-node inference, " )
"please install Ray with `pip install "
"ray`."
)
backend = "ray"
elif self.data_parallel_backend == "ray": elif self.data_parallel_backend == "ray":
logger.info( logger.info(
"Using ray distributed inference because " "Using ray distributed inference because "
......
...@@ -255,12 +255,33 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): ...@@ -255,12 +255,33 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
try: try:
ray.get(pg_ready_ref, timeout=0) ray.get(pg_ready_ref, timeout=0)
except ray.exceptions.GetTimeoutError: except ray.exceptions.GetTimeoutError:
raise ValueError( # Provide more helpful error message when GPU count is exceeded
"Cannot provide a placement group of " total_gpu_required = sum(spec.get("GPU", 0) for spec in placement_group_specs)
f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " # If more than one GPU is required for the placement group, provide a
"`ray status` and `ray list nodes` to make sure the cluster has " # more specific error message.
"enough resources." # We use >1 here because multi-GPU (tensor parallel) jobs are more
) from None # likely to fail due to insufficient cluster resources, and users may
# need to adjust tensor_parallel_size to fit available GPUs.
if total_gpu_required > 1:
raise ValueError(
f"Cannot provide a placement group requiring "
f"{total_gpu_required} GPUs "
f"(placement_group_specs={placement_group_specs}) within "
f"{PG_WAIT_TIMEOUT} seconds.\n"
f"Tensor parallel size may exceed available GPUs in your "
f"cluster. Check resources with `ray status` and "
f"`ray list nodes`.\n"
f"If running on K8s with limited GPUs, consider reducing "
f"--tensor-parallel-size to match available GPU resources."
) from None
else:
raise ValueError(
"Cannot provide a placement group of "
f"{placement_group_specs=} within "
f"{PG_WAIT_TIMEOUT} seconds. See "
"`ray status` and `ray list nodes` to make sure the cluster "
"has enough resources."
) from None
def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
...@@ -299,6 +320,23 @@ def initialize_ray_cluster( ...@@ -299,6 +320,23 @@ def initialize_ray_cluster(
assert_ray_available() assert_ray_available()
from vllm.platforms import current_platform from vllm.platforms import current_platform
# Prevalidate GPU requirements before Ray processing
if current_platform.is_cuda() and parallel_config.world_size > 1:
from vllm.utils import cuda_device_count_stateless
available_gpus = cuda_device_count_stateless()
if parallel_config.world_size > available_gpus:
logger.warning(
"Tensor parallel size (%d) exceeds available GPUs (%d). "
"This may result in Ray placement group allocation failures. "
"Consider reducing tensor_parallel_size to %d or less, "
"or ensure your Ray cluster has %d GPUs available.",
parallel_config.world_size,
available_gpus,
available_gpus,
parallel_config.world_size,
)
if ray.is_initialized(): if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.") logger.info("Ray is already initialized. Skipping Ray initialization.")
elif current_platform.is_rocm() or current_platform.is_xpu(): elif current_platform.is_rocm() or current_platform.is_xpu():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment