Unverified commit 16c472ab, authored by Nick Hill, committed by GitHub
Browse files

[Core] Move ray-specific WorkerWrapperBase methods to RayWorkerWrapper (#35328)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 3b23d57c
...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform ...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.serial_utils import run_method
from vllm.v1.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -50,6 +51,29 @@ try: ...@@ -50,6 +51,29 @@ try:
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
    """
    Adjust the rpc_rank based on the given mapping.

    It is only used during the initialization of the executor,
    to adjust the rpc_rank of workers after we create all workers.

    Args:
        rank_mapping: Maps a current rpc_rank to its replacement. If this
            wrapper's rpc_rank is not a key of the mapping, the rank is
            left unchanged.
    """
    # One EAFP lookup instead of a membership test followed by indexing.
    try:
        new_rank = rank_mapping[self.rpc_rank]
    except KeyError:
        # This wrapper is not being remapped.
        return
    self.rpc_rank = new_rank
def execute_method(self, method: str | bytes, *args, **kwargs):
    """
    Run `method` on this wrapper via `run_method`, logging any failure.

    Args:
        method: The method to run, given as `str` or `bytes`; it is handed
            to `run_method` unchanged.
        *args: Positional arguments, forwarded to `run_method` as a tuple.
        **kwargs: Keyword arguments, forwarded to `run_method` as a dict.

    Returns:
        Whatever `run_method` returns.

    Raises:
        Exception: Any exception raised by `run_method` is logged with its
            traceback and then re-raised.
    """
    try:
        return run_method(self, method, args, kwargs)
    except Exception:
        # if the driver worker also execute methods,
        # exceptions in the rest worker may cause deadlock in rpc
        # see https://github.com/vllm-project/vllm/issues/3455
        msg = (
            f"Error executing method {method!r}. "
            "This might cause deadlock in distributed execution."
        )
        logger.exception(msg)
        # Bare `raise` re-raises the active exception as-is, without
        # adding a second traceback entry for this handler.
        raise
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
......
...@@ -15,7 +15,6 @@ from vllm.tracing import instrument ...@@ -15,7 +15,6 @@ from vllm.tracing import instrument
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
...@@ -211,15 +210,6 @@ class WorkerWrapperBase: ...@@ -211,15 +210,6 @@ class WorkerWrapperBase:
if self.worker is not None: if self.worker is not None:
self.worker.shutdown() self.worker.shutdown()
def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
    """
    Adjust the rpc_rank based on the given mapping.

    It is only used during the initialization of the executor,
    to adjust the rpc_rank of workers after we create all workers.

    Args:
        rank_mapping: Maps a current rpc_rank to the rank that should
            replace it. A rank with no entry in the mapping is kept as is.
    """
    # Look the rank up once (EAFP) rather than testing membership and
    # then indexing the mapping a second time.
    try:
        remapped = rank_mapping[self.rpc_rank]
    except KeyError:
        # No entry for this wrapper's rank: nothing to adjust.
        return
    self.rpc_rank = remapped
def update_environment_variables( def update_environment_variables(
self, self,
envs_list: list[dict[str, str]], envs_list: list[dict[str, str]],
...@@ -325,25 +315,6 @@ class WorkerWrapperBase: ...@@ -325,25 +315,6 @@ class WorkerWrapperBase:
# To make vLLM config available during device initialization # To make vLLM config available during device initialization
self.worker.init_device() # type: ignore self.worker.init_device() # type: ignore
def execute_method(self, method: str | bytes, *args, **kwargs):
    """
    Run `method` on this wrapper via `run_method`, logging any failure.

    Args:
        method: The method to run, given as `str` or `bytes`; it is handed
            to `run_method` unchanged.
        *args: Positional arguments, forwarded to `run_method` as a tuple.
        **kwargs: Keyword arguments, forwarded to `run_method` as a dict.

    Returns:
        Whatever `run_method` returns.

    Raises:
        Exception: Any exception raised by `run_method` is logged with its
            traceback and then re-raised.
    """
    try:
        # method resolution order:
        # if a method is defined in this class, it will be called directly.
        # otherwise, since we define `__getattr__` and redirect attribute
        # query to `self.worker`, the method will be called on the worker.
        return run_method(self, method, args, kwargs)
    except Exception:
        # if the driver worker also execute methods,
        # exceptions in the rest worker may cause deadlock in rpc like ray
        # see https://github.com/vllm-project/vllm/issues/3455
        # print the error and inform the user to solve the error
        msg = (
            f"Error executing method {method!r}. "
            "This might cause deadlock in distributed execution."
        )
        logger.exception(msg)
        # Bare `raise` re-raises the active exception as-is, without
        # adding a second traceback entry for this handler.
        raise
def __getattr__(self, attr: str): def __getattr__(self, attr: str):
return getattr(self.worker, attr) return getattr(self.worker, attr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment