Unverified Commit 61e59274 authored by Rui Qiao's avatar Rui Qiao Committed by GitHub
Browse files

[Core] Introduce SPMD worker execution using Ray accelerated DAG (#6032)


Signed-off-by: default avatarRui Qiao <ruisearch42@gmail.com>
Co-authored-by: default avatarStephanie Wang <swang@cs.berkeley.edu>
parent d25877dd
...@@ -84,6 +84,8 @@ steps: ...@@ -84,6 +84,8 @@ steps:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
...@@ -108,6 +110,7 @@ steps: ...@@ -108,6 +110,7 @@ steps:
# We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
......
...@@ -6,6 +6,7 @@ from typing import Set, Type, TypeVar, Union ...@@ -6,6 +6,7 @@ from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, MultiModalConfig, LoRAConfig, ModelConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig, ObservabilityConfig, ParallelConfig,
...@@ -414,6 +415,9 @@ class LLMEngine: ...@@ -414,6 +415,9 @@ class LLMEngine:
elif distributed_executor_backend == "mp": elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import ( from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor) MultiprocessingGPUExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor executor_class = MultiprocessingGPUExecutor
else: else:
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
...@@ -426,6 +430,7 @@ class LLMEngine: ...@@ -426,6 +430,7 @@ class LLMEngine:
usage_context=usage_context, usage_context=usage_context,
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
) )
return engine return engine
def __reduce__(self): def __reduce__(self):
......
...@@ -34,6 +34,7 @@ if TYPE_CHECKING: ...@@ -34,6 +34,7 @@ if TYPE_CHECKING:
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
...@@ -261,6 +262,13 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -261,6 +262,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)),
# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER":
lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
# If the env var is set, it uses the Ray's compiled DAG API # If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead. # which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
......
...@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks) num_cpu_blocks=num_cpu_blocks)
def execute_model( def execute_model(
self, execute_model_req: ExecuteModelRequest self,
) -> Optional[List[SamplerOutput]]: execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers( self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop", "start_worker_execution_loop",
...@@ -73,7 +73,9 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -73,7 +73,9 @@ class DistributedGPUExecutor(GPUExecutor):
**self.extra_execute_model_run_workers_kwargs) **self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req) driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
......
import asyncio import asyncio
import os import os
import pickle
from collections import defaultdict from collections import defaultdict
from itertools import islice, repeat from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
...@@ -23,12 +22,30 @@ if TYPE_CHECKING: ...@@ -23,12 +22,30 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor): class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None: def _init_executor(self) -> None:
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1")
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.parallel_config.distributed_executor_backend == "ray" assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
...@@ -40,11 +57,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -40,11 +57,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
self.forward_dag = None self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True
def _configure_ray_workers_use_nsight(self, def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]: ray_remote_kwargs) -> Dict[str, Any]:
...@@ -110,6 +123,9 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -110,6 +123,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
if self.use_ray_spmd_worker:
self.workers.append(worker)
else:
worker_ip = ray.get(worker.get_node_ip.remote()) worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None: if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it # If the worker is on the same node as the driver, we use it
...@@ -124,7 +140,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -124,7 +140,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Else, added to the list of workers. # Else, added to the list of workers.
self.workers.append(worker) self.workers.append(worker)
if self.driver_dummy_worker is None: if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError( raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider " "Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a " "adjusting the Ray placement group or running the driver on a "
...@@ -254,9 +270,23 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -254,9 +270,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
Passing None will cause the driver to stop the model execution Passing None will cause the driver to stop the model execution
loop running in each of the remote workers. loop running in each of the remote workers.
""" """
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
return self.driver_worker.execute_method("execute_model", return self.driver_worker.execute_method("execute_model",
execute_model_req) execute_model_req)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
outputs = ray.get(self.forward_dag.execute(execute_model_req))
return outputs[0]
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
...@@ -266,7 +296,6 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -266,7 +296,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. Can be used in the following """Runs the given method on all workers. Can be used in the following
...@@ -281,6 +310,10 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -281,6 +310,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
- all_args/all_kwargs: args/kwargs for each worker are specified - all_args/all_kwargs: args/kwargs for each worker are specified
individually individually
""" """
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode.")
if max_concurrent_workers: if max_concurrent_workers:
raise NotImplementedError( raise NotImplementedError(
...@@ -289,25 +322,21 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -289,25 +322,21 @@ class RayGPUExecutor(DistributedGPUExecutor):
count = len(self.workers) if not \ count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \ async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers) else len(self.non_driver_workers)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \ all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None) else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None) else islice(all_kwargs, first_worker_args_index, None)
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else:
# Start the ray workers first. # Start the ray workers first.
ray_workers = self.workers ray_workers = self.workers
if async_run_tensor_parallel_workers_only: if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers ray_workers = self.non_driver_workers
ray_worker_outputs = [ ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, worker.execute_method.remote(method, *worker_args, **worker_kwargs)
**worker_kwargs)
for (worker, worker_args, worker_kwargs for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs) ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
] ]
...@@ -316,44 +345,46 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -316,44 +345,46 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Just return futures # Just return futures
return ray_worker_outputs return ray_worker_outputs
driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0] driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers. # Start the driver worker after all the ray workers.
if not use_dummy_driver: if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method( driver_worker_output = [
method, *driver_args, **driver_kwargs) self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
else: else:
assert self.driver_dummy_worker is not None assert self.driver_dummy_worker is not None
driver_worker_output = ray.get( driver_worker_output = [
ray.get(
self.driver_dummy_worker.execute_method.remote( self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs)) method, *driver_args, **driver_kwargs))
]
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs) ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with """Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete.""" async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks) ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self): def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources import pkg_resources
required_version = "2.9" from packaging import version
current_version = pkg_resources.get_distribution("ray").version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version: if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is " raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}") f"required, but found {current_version}")
...@@ -365,23 +396,47 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -365,23 +396,47 @@ class RayGPUExecutor(DistributedGPUExecutor):
# a dummy value for now. It will be fixed soon. # a dummy value for now. It will be fixed soon.
with InputNode() as input_data: with InputNode() as input_data:
forward_dag = MultiOutputNode([ forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote. worker.execute_model_spmd.bind( # type: ignore[attr-defined]
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers input_data) for worker in self.workers
]) ])
return forward_dag.experimental_compile() return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method) self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
dag_future = await self.forward_dag.execute_async(execute_model_req)
outputs = await dag_future
return outputs[0]
async def _driver_execute_model_async( async def _driver_execute_model_async(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
if self.pp_locks is None: if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual # This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time # engines can't execute on the same stage at the same time
...@@ -415,8 +470,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -415,8 +470,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return results[-1] return results[-1]
async def _start_worker_execution_loop(self): async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
coros = [ coros = [
worker.execute_method.remote("start_worker_execution_loop") worker.execute_method.remote("start_worker_execution_loop")
for worker in self.non_driver_workers for worker in self.non_driver_workers
] ]
return await asyncio.gather(*coros) return await asyncio.gather(*coros)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
import pickle
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_xpu from vllm.utils import get_ip, is_hip, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
...@@ -31,16 +31,18 @@ try: ...@@ -31,16 +31,18 @@ try:
gpu_ids = ray.get_gpu_ids() gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids return node_id, gpu_ids
def execute_model_compiled_dag_remote(self, ignored): def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
"""Used only when compiled DAG is enabled.""" """Used only when SPMD worker and compiled DAG are both
enabled."""
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
import torch import torch
if not self.compiled_dag_cuda_device_set: if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device) torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
output = self.worker.execute_model() return self.worker._execute_model_spmd(execute_model_req)
output = pickle.dumps(output)
return output
ray_import_err = None ray_import_err = None
......
import asyncio import asyncio
import os import os
import pickle
from collections import defaultdict from collections import defaultdict
from itertools import islice, repeat from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union) Tuple, Union)
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
...@@ -30,7 +30,7 @@ logger = init_logger(__name__) ...@@ -30,7 +30,7 @@ logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API # If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead. # which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor): class RayXPUExecutor(DistributedGPUExecutor):
...@@ -72,10 +72,9 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -72,10 +72,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self.forward_dag = None self.forward_dag = None
if USE_RAY_COMPILED_DAG: if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag() self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
# This is non-None when the execute model loop is running # This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case. # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...@@ -270,7 +269,6 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -270,7 +269,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. Can be used in the following """Runs the given method on all workers. Can be used in the following
...@@ -293,26 +291,20 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -293,26 +291,20 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None) else islice(all_kwargs, 1, None)
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first. # Start the ray workers first.
ray_worker_outputs = [ ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, worker.execute_method.remote(method, *worker_args, **worker_kwargs)
**worker_kwargs)
for (worker, worker_args, worker_kwargs for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs) ) in zip(self.workers, all_worker_args, all_worker_kwargs)
] ]
if async_run_remote_workers_only: if async_run_remote_workers_only:
# Just return futures # Just return futures
return ray_worker_outputs return ray_worker_outputs
driver_worker_output = []
driver_args = args if all_args is None else all_args[0] driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers. # Start the driver worker after all the ray workers.
if not use_dummy_driver: if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method( driver_worker_output = self.driver_worker.execute_method(
...@@ -324,36 +316,28 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -324,36 +316,28 @@ class RayXPUExecutor(DistributedGPUExecutor):
method, *driver_args, **driver_kwargs)) method, *driver_args, **driver_kwargs))
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs) ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with """Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete.""" async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks) ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self): def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources import pkg_resources
required_version = "2.9" from packaging import version
current_version = pkg_resources.get_distribution("ray").version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version: if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is " raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}") f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray assert self.parallel_config.distributed_executor_backend == "ray"
# Right now, compiled DAG requires at least 1 arg. We send # Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon. # a dummy value for now. It will be fixed soon.
...@@ -363,7 +347,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -363,7 +347,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
bind( # type: ignore[attr-defined] bind( # type: ignore[attr-defined]
input_data) for worker in self.workers input_data) for worker in self.workers
]) ])
return forward_dag.experimental_compile() return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def check_health(self) -> None: def check_health(self) -> None:
"""Raises an error if engine is unhealthy.""" """Raises an error if engine is unhealthy."""
......
...@@ -281,6 +281,33 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -281,6 +281,33 @@ class LocalOrDistributedWorkerBase(WorkerBase):
# list to conform to interface. # list to conform to interface.
return output return output
def _execute_model_spmd(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, (
"_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None)
class WorkerWrapperBase: class WorkerWrapperBase:
""" """
...@@ -296,7 +323,7 @@ class WorkerWrapperBase: ...@@ -296,7 +323,7 @@ class WorkerWrapperBase:
trust_remote_code: bool = False) -> None: trust_remote_code: bool = False) -> None:
self.worker_module_name = worker_module_name self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name self.worker_class_name = worker_class_name
self.worker = None self.worker: Optional[WorkerBase] = None
if trust_remote_code: if trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
...@@ -323,7 +350,9 @@ class WorkerWrapperBase: ...@@ -323,7 +350,9 @@ class WorkerWrapperBase:
mod = importlib.import_module(self.worker_module_name) mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name) worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs) self.worker = worker_class(*args, **kwargs)
assert self.worker is not None
def execute_method(self, method, *args, **kwargs): def execute_method(self, method, *args, **kwargs):
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment