Unverified Commit 4b8826b3 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

refactor(vllm): scope ray import to scale_elastic_ep, remove module-level ray state (#7618)


Signed-off-by: default avatartzulingk <tzulingk@nvidia.com>
parent 7edb07b5
...@@ -13,16 +13,7 @@ import time ...@@ -13,16 +13,7 @@ import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Final,
Generic,
Optional,
TypeVar,
)
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -65,48 +56,6 @@ from .multimodal_utils.hash_utils import compute_mm_uuids_from_images ...@@ -65,48 +56,6 @@ from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
if TYPE_CHECKING:
import ray
import ray.util.state as _ray_util_state
try:
import ray
import ray.util.state as _ray_util_state
except ModuleNotFoundError:
ray = None
_ray_util_state = None
# TODO(upstream-vllm): remove this patch once vLLM fixes add_dp_placement_groups in
# vllm/v1/engine/utils.py to use ray.nodes() instead of ray.util.state.list_nodes().
#
# Patch ray.util.state.list_nodes to use the GCS API instead of the dashboard HTTP
# API (127.0.0.1:8265/api/v0/nodes). The dynamo image installs ray core only (not
# ray[default]), so the dashboard HTTP server starts in --minimal mode with the HTTP
# server disabled. vLLM's add_dp_placement_groups calls list_nodes() which requires
# that HTTP endpoint, causing scale_elastic_ep to fail with "Failed to connect to
# API server".
#
# ray.nodes() uses the GCS gRPC channel directly (no dashboard process needed) and
# returns the same information. This patch makes elastic EP scaling self-contained.
#
# Format mapping:
# list_nodes() → objects with .node_ip and .node_id
# ray.nodes() → dicts with "NodeManagerAddress" and "NodeID"
class _NodeInfo:
__slots__ = ("node_ip", "node_id")
def __init__(self, d: dict) -> None:
self.node_ip: str = d["NodeManagerAddress"]
self.node_id: str = d["NodeID"]
if ray is not None and _ray_util_state is not None:
_ray_util_state.list_nodes = lambda **kw: [
_NodeInfo(n) for n in ray.nodes() if n.get("Alive", False)
]
# Multimodal data dictionary keys # Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url" IMAGE_URL_KEY: Final = "image_url"
VIDEO_URL_KEY: Final = "video_url" VIDEO_URL_KEY: Final = "video_url"
...@@ -576,6 +525,39 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -576,6 +525,39 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
logger.info(f"[ElasticEP] Scaling to new_data_parallel_size={new_dp_size}") logger.info(f"[ElasticEP] Scaling to new_data_parallel_size={new_dp_size}")
try: try:
# TODO(upstream-vllm): remove this patch once vLLM fixes
# add_dp_placement_groups in vllm/v1/engine/utils.py to use ray.nodes()
# instead of ray.util.state.list_nodes().
#
# Patch ray.util.state.list_nodes to use the GCS API instead of the
# dashboard HTTP API (127.0.0.1:8265/api/v0/nodes). The dynamo image
# installs ray core only (not ray[default]), so the dashboard HTTP server
# starts in --minimal mode with the HTTP server disabled. vLLM's
# add_dp_placement_groups calls list_nodes() which requires that HTTP
# endpoint, causing scale_elastic_ep to fail with "Failed to connect to
# API server".
#
# ray.nodes() uses the GCS gRPC channel directly (no dashboard process
# needed) and returns the same information. Imported lazily so ray is not
# required at module load time (absent in non-elastic-EP deployments).
#
# Format mapping:
# list_nodes() → objects with .node_ip and .node_id
# ray.nodes() → dicts with "NodeManagerAddress" and "NodeID"
import ray
import ray.util.state as _ray_util_state
class _NodeInfo:
__slots__ = ("node_id", "node_ip")
def __init__(self, d: dict) -> None:
self.node_ip: str = d["NodeManagerAddress"]
self.node_id: str = d["NodeID"]
_ray_util_state.list_nodes = lambda **kw: [
_NodeInfo(n) for n in ray.nodes() if n.get("Alive", False)
]
await self.engine_client.scale_elastic_ep(new_dp_size) await self.engine_client.scale_elastic_ep(new_dp_size)
logger.info(f"[ElasticEP] Scaling to dp={new_dp_size} complete") logger.info(f"[ElasticEP] Scaling to dp={new_dp_size} complete")
return { return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment