Unverified commit a6fed020, authored by Li, Jiang and committed by GitHub
Browse files

[V1][PP] Support PP for MultiprocExecutor (#14219)


Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
parent d419aa5d
...@@ -100,9 +100,8 @@ class PPTestSettings: ...@@ -100,9 +100,8 @@ class PPTestSettings:
eager_mode=True, eager_mode=True,
chunked_prefill=False), chunked_prefill=False),
], ],
# only ray is supported for V1 distributed_backends=["mp", "mp", "ray", "ray"],
distributed_backends=["mp", "ray", "ray"], vllm_major_versions=["0", "1", "0", "1"],
vllm_major_versions=["0", "0", "1"],
task=task, task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only, test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format), load_format=load_format),
...@@ -350,6 +349,11 @@ def _compare_tp( ...@@ -350,6 +349,11 @@ def _compare_tp(
# Temporary. Currently when zeromq + SPMD is used, it does not properly # Temporary. Currently when zeromq + SPMD is used, it does not properly
# terminate because of a Ray Compiled Graph issue. # terminate because of a Ray Compiled Graph issue.
common_args.append("--disable-frontend-multiprocessing") common_args.append("--disable-frontend-multiprocessing")
elif distributed_backend == "mp":
# Both V0/V1 of multiprocessing executor support PP
pp_env = {
"VLLM_USE_V1": vllm_major_version,
}
else: else:
pp_env = None pp_env = None
......
...@@ -1338,11 +1338,10 @@ class EngineArgs: ...@@ -1338,11 +1338,10 @@ class EngineArgs:
and _warn_or_fallback("Engine in background thread")): and _warn_or_fallback("Engine in background thread")):
return False return False
# PP is supported on V1 with Ray distributed executor,
# but off for MP distributed executor for now.
if (self.pipeline_parallel_size > 1 if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend != "ray"): and self.distributed_executor_backend not in ["ray", "mp"]):
name = "Pipeline Parallelism without Ray distributed executor" name = "Pipeline Parallelism without Ray distributed executor " \
"or multiprocessing executor"
_raise_or_fallback(feature_name=name, recommend_to_remove=False) _raise_or_fallback(feature_name=name, recommend_to_remove=False)
return False return False
......
...@@ -8,7 +8,7 @@ import threading ...@@ -8,7 +8,7 @@ import threading
import time import time
import traceback import traceback
import weakref import weakref
from concurrent.futures import Future from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from functools import partial from functools import partial
...@@ -53,10 +53,11 @@ class MultiprocExecutor(Executor): ...@@ -53,10 +53,11 @@ class MultiprocExecutor(Executor):
self.world_size = self.parallel_config.world_size self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size tensor_parallel_size = self.parallel_config.tensor_parallel_size
assert self.world_size == tensor_parallel_size, ( pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
f"world_size ({self.world_size}) must be equal to the " f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}). " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"Pipeline parallelism is not yet implemented in v1") f"_parallel_size ({pp_parallel_size}). ")
# Set multiprocessing envs that are common to V0 and V1 # Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config) set_multiprocessing_worker_envs(self.parallel_config)
...@@ -104,6 +105,17 @@ class MultiprocExecutor(Executor): ...@@ -104,6 +105,17 @@ class MultiprocExecutor(Executor):
self._ensure_worker_termination( self._ensure_worker_termination(
[w.proc for w in unready_workers]) [w.proc for w in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")
self.output_rank = self._get_output_rank()
def start_worker_monitor(self): def start_worker_monitor(self):
workers = self.workers workers = self.workers
self_ref = weakref.ref(self) self_ref = weakref.ref(self)
...@@ -145,7 +157,9 @@ class MultiprocExecutor(Executor): ...@@ -145,7 +157,9 @@ class MultiprocExecutor(Executor):
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
(output, ) = self.collective_rpc("execute_model", (output, ) = self.collective_rpc("execute_model",
args=(scheduler_output, ), args=(scheduler_output, ),
rank0_reply_only=True, unique_reply_rank=self.output_rank,
non_block=self.max_concurrent_batches
> 1,
timeout=EXECUTE_MODEL_TIMEOUT_S) timeout=EXECUTE_MODEL_TIMEOUT_S)
return output return output
...@@ -154,7 +168,8 @@ class MultiprocExecutor(Executor): ...@@ -154,7 +168,8 @@ class MultiprocExecutor(Executor):
timeout: Optional[float] = None, timeout: Optional[float] = None,
args: tuple = (), args: tuple = (),
kwargs: Optional[dict] = None, kwargs: Optional[dict] = None,
rank0_reply_only: bool = False) -> list[Any]: non_block: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
if self.is_failed: if self.is_failed:
raise RuntimeError("Executor failed.") raise RuntimeError("Executor failed.")
...@@ -171,22 +186,35 @@ class MultiprocExecutor(Executor): ...@@ -171,22 +186,35 @@ class MultiprocExecutor(Executor):
send_method = cloudpickle.dumps( send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL) method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue( self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, rank0_reply_only)) (send_method, args, kwargs, unique_reply_rank))
workers = (self.workers[0], ) if rank0_reply_only else self.workers workers = (self.workers[unique_reply_rank],
responses = [None] * len(workers) ) if unique_reply_rank is not None else self.workers
for w in workers: responses = []
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic()) def get_response(w: WorkerProcHandle,
dequeue_timeout: Optional[float] = None,
cancel_event: Optional[threading.Event] = None):
status, result = w.worker_response_mq.dequeue( status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=self.shutdown_event) timeout=dequeue_timeout, cancel=cancel_event)
if status != WorkerProc.ResponseStatus.SUCCESS: if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError( raise RuntimeError(
f"Worker failed with error '{result}', please check the" f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause") " stack trace above for the root cause")
return result
responses[w.rank] = result for w in workers:
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic())
if non_block:
result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event)
else:
result = get_response(w, dequeue_timeout)
responses.append(result)
return responses return responses
except TimeoutError as e: except TimeoutError as e:
...@@ -225,6 +253,11 @@ class MultiprocExecutor(Executor): ...@@ -225,6 +253,11 @@ class MultiprocExecutor(Executor):
if not getattr(self, 'shutting_down', False): if not getattr(self, 'shutting_down', False):
self.shutting_down = True self.shutting_down = True
self.shutdown_event.set() self.shutdown_event.set()
if self.io_thread_pool is not None:
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
self.io_thread_pool = None
for w in self.workers: for w in self.workers:
w.worker_response_mq = None w.worker_response_mq = None
self._ensure_worker_termination([w.proc for w in self.workers]) self._ensure_worker_termination([w.proc for w in self.workers])
...@@ -235,6 +268,22 @@ class MultiprocExecutor(Executor): ...@@ -235,6 +268,22 @@ class MultiprocExecutor(Executor):
self.collective_rpc("check_health", timeout=10) self.collective_rpc("check_health", timeout=10)
return return
@property
def max_concurrent_batches(self) -> int:
    """Return the maximum number of concurrent batches.

    The value is read directly from
    ``self.parallel_config.pipeline_parallel_size``; no other state is
    consulted.
    """
    return self.parallel_config.pipeline_parallel_size
def _get_output_rank(self) -> int:
    """Return the rank of the worker that replies with the model output.

    Computed as ``self.world_size`` minus
    ``self.parallel_config.tensor_parallel_size``.
    """
    # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
    # (the first TP worker of the last PP stage).
    # Example:
    # Assuming TP=8, PP=4, then the world_size=32
    # 0-7, PP rank 0
    # 8-15, PP rank 1
    # 16-23, PP rank 2
    # 24-31, PP rank 3
    # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
    return self.world_size - self.parallel_config.tensor_parallel_size
@dataclass @dataclass
class UnreadyWorkerProcHandle: class UnreadyWorkerProcHandle:
...@@ -280,12 +329,14 @@ class WorkerProc: ...@@ -280,12 +329,14 @@ class WorkerProc:
all_kwargs: list[dict] = [ all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size) {} for _ in range(vllm_config.parallel_config.world_size)
] ]
is_driver_worker = (
rank % vllm_config.parallel_config.tensor_parallel_size == 0)
all_kwargs[rank] = { all_kwargs[rank] = {
"vllm_config": vllm_config, "vllm_config": vllm_config,
"local_rank": local_rank, "local_rank": local_rank,
"rank": rank, "rank": rank,
"distributed_init_method": distributed_init_method, "distributed_init_method": distributed_init_method,
"is_driver_worker": rank == 0, "is_driver_worker": is_driver_worker,
} }
wrapper.init_worker(all_kwargs) wrapper.init_worker(all_kwargs)
self.worker = wrapper self.worker = wrapper
...@@ -455,7 +506,7 @@ class WorkerProc: ...@@ -455,7 +506,7 @@ class WorkerProc:
def worker_busy_loop(self): def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers""" """Main busy loop for Multiprocessing Workers"""
while True: while True:
method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue() method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
try: try:
if isinstance(method, str): if isinstance(method, str):
...@@ -470,11 +521,11 @@ class WorkerProc: ...@@ -470,11 +521,11 @@ class WorkerProc:
logger.exception("WorkerProc hit an exception.") logger.exception("WorkerProc hit an exception.")
# exception might not be serializable, so we convert it to # exception might not be serializable, so we convert it to
# string, only for logging purpose. # string, only for logging purpose.
if not rank0_only or self.rank == 0: if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue( self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e))) (WorkerProc.ResponseStatus.FAILURE, str(e)))
continue continue
if not rank0_only or self.rank == 0: if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue( self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output)) (WorkerProc.ResponseStatus.SUCCESS, output))
...@@ -1016,7 +1016,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1016,7 +1016,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
# Update KVConnector with the KVConnector metadata forward(). # Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata( get_kv_transfer_group().bind_connector_metadata(
......
...@@ -15,11 +15,12 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -15,11 +15,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes from vllm.utils import GiB_bytes
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
...@@ -266,7 +267,22 @@ class Worker(WorkerBase): ...@@ -266,7 +267,22 @@ class Worker(WorkerBase):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output) intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None return output if self.is_driver_worker else None
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment