Unverified Commit a6fed020 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[V1][PP] Support PP for MultiprocExecutor (#14219)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
Signed-off-by: default avatarjiang.li <jiang1.li@intel.com>
parent d419aa5d
......@@ -100,9 +100,8 @@ class PPTestSettings:
eager_mode=True,
chunked_prefill=False),
],
# only ray is supported for V1
distributed_backends=["mp", "ray", "ray"],
vllm_major_versions=["0", "0", "1"],
distributed_backends=["mp", "mp", "ray", "ray"],
vllm_major_versions=["0", "1", "0", "1"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
......@@ -350,6 +349,11 @@ def _compare_tp(
# Temporary. Currently when zeromq + SPMD is used, it does not properly
# terminate because of a Ray Compiled Graph issue.
common_args.append("--disable-frontend-multiprocessing")
elif distributed_backend == "mp":
# Both V0/V1 of multiprocessing executor support PP
pp_env = {
"VLLM_USE_V1": vllm_major_version,
}
else:
pp_env = None
......
......@@ -1338,11 +1338,10 @@ class EngineArgs:
and _warn_or_fallback("Engine in background thread")):
return False
# PP is supported on V1 with Ray distributed executor,
# but off for MP distributed executor for now.
if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend != "ray"):
name = "Pipeline Parallelism without Ray distributed executor"
and self.distributed_executor_backend not in ["ray", "mp"]):
name = "Pipeline Parallelism without Ray distributed executor " \
"or multiprocessing executor"
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
return False
......
......@@ -8,7 +8,7 @@ import threading
import time
import traceback
import weakref
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
......@@ -53,10 +53,11 @@ class MultiprocExecutor(Executor):
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
assert self.world_size == tensor_parallel_size, (
pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}). "
f"Pipeline parallelism is not yet implemented in v1")
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}). ")
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)
......@@ -104,6 +105,17 @@ class MultiprocExecutor(Executor):
self._ensure_worker_termination(
[w.proc for w in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")
self.output_rank = self._get_output_rank()
def start_worker_monitor(self):
workers = self.workers
self_ref = weakref.ref(self)
......@@ -145,7 +157,9 @@ class MultiprocExecutor(Executor):
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
(output, ) = self.collective_rpc("execute_model",
args=(scheduler_output, ),
rank0_reply_only=True,
unique_reply_rank=self.output_rank,
non_block=self.max_concurrent_batches
> 1,
timeout=EXECUTE_MODEL_TIMEOUT_S)
return output
......@@ -154,7 +168,8 @@ class MultiprocExecutor(Executor):
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
rank0_reply_only: bool = False) -> list[Any]:
non_block: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
if self.is_failed:
raise RuntimeError("Executor failed.")
......@@ -171,22 +186,35 @@ class MultiprocExecutor(Executor):
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, rank0_reply_only))
(send_method, args, kwargs, unique_reply_rank))
workers = (self.workers[0], ) if rank0_reply_only else self.workers
responses = [None] * len(workers)
for w in workers:
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic())
workers = (self.workers[unique_reply_rank],
) if unique_reply_rank is not None else self.workers
responses = []
def get_response(w: WorkerProcHandle,
dequeue_timeout: Optional[float] = None,
cancel_event: Optional[threading.Event] = None):
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=self.shutdown_event)
timeout=dequeue_timeout, cancel=cancel_event)
if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError(
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause")
return result
for w in workers:
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic())
if non_block:
result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event)
else:
result = get_response(w, dequeue_timeout)
responses[w.rank] = result
responses.append(result)
return responses
except TimeoutError as e:
......@@ -225,6 +253,11 @@ class MultiprocExecutor(Executor):
if not getattr(self, 'shutting_down', False):
self.shutting_down = True
self.shutdown_event.set()
if self.io_thread_pool is not None:
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
self.io_thread_pool = None
for w in self.workers:
w.worker_response_mq = None
self._ensure_worker_termination([w.proc for w in self.workers])
......@@ -235,6 +268,22 @@ class MultiprocExecutor(Executor):
self.collective_rpc("check_health", timeout=10)
return
@property
def max_concurrent_batches(self) -> int:
return self.parallel_config.pipeline_parallel_size
def _get_output_rank(self) -> int:
# Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
# (the first TP worker of the last PP stage).
# Example:
# Assuming TP=8, PP=4, then the world_size=32
# 0-7, PP rank 0
# 8-15, PP rank 1
# 16-23, PP rank 2
# 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return self.world_size - self.parallel_config.tensor_parallel_size
@dataclass
class UnreadyWorkerProcHandle:
......@@ -280,12 +329,14 @@ class WorkerProc:
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
is_driver_worker = (
rank % vllm_config.parallel_config.tensor_parallel_size == 0)
all_kwargs[rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
"is_driver_worker": rank == 0,
"is_driver_worker": is_driver_worker,
}
wrapper.init_worker(all_kwargs)
self.worker = wrapper
......@@ -455,7 +506,7 @@ class WorkerProc:
def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
try:
if isinstance(method, str):
......@@ -470,11 +521,11 @@ class WorkerProc:
logger.exception("WorkerProc hit an exception.")
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
if not rank0_only or self.rank == 0:
if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue
if not rank0_only or self.rank == 0:
if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
......@@ -1016,7 +1016,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
) -> Union[ModelRunnerOutput, IntermediateTensors]:
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata(
......
......@@ -15,11 +15,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
......@@ -266,7 +267,22 @@ class Worker(WorkerBase):
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None
def profile(self, is_start: bool = True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment