Unverified Commit d90d8eb6 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Async scheduling and PP compatibility with DP (#23770)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent 0a2f4c07
...@@ -306,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -306,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Schedule Batch 1: (10, req0) # Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue()[0] is None assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 1 assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 10 assert scheduler_output.num_scheduled_tokens["0"] == 10
# num_computed_tokens should have been updated immediately. # num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[ assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10 req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1) # Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue()[0] is None assert engine_core.step_with_batch_queue()[0] == {}
assert engine_core.batch_queue.qsize() == 2 assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 2 assert scheduler_output.num_scheduled_tokens["0"] == 2
assert scheduler_output.num_scheduled_tokens["1"] == 8 assert scheduler_output.num_scheduled_tokens["1"] == 8
# num_computed_tokens should have been updated immediately. # num_computed_tokens should have been updated immediately.
...@@ -325,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -325,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert engine_core.scheduler.get_num_unfinished_requests() == 2 assert engine_core.scheduler.get_num_unfinished_requests() == 2
# Batch queue is full. Finish Batch 1. # Finish Batch 1 and schedule Batch 3: (4, req1).
engine_core.step_with_batch_queue() # Note that req0 cannot be scheduled
# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# because it is in the decoding stage now. # because it is in the decoding stage now.
engine_core.step_with_batch_queue() engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2 assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 4 assert scheduler_output.num_scheduled_tokens["1"] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0. # Finish Batch 2. Get first token of req0.
# Schedule Batch 4: (1, req0).
output = engine_core.step_with_batch_queue()[0].get(0) output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
scheduler_output = engine_core.batch_queue[-1][1]
# Schedule Batch 4: (1, req0).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 1 assert scheduler_output.num_scheduled_tokens["0"] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1. # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
output = engine_core.step_with_batch_queue()[0].get(0) output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
scheduler_output = engine_core.batch_queue[-1][1]
# Schedule Batch 5: (1, req1).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 1 assert scheduler_output.num_scheduled_tokens["1"] == 1
# Loop until req0 is finished. # Loop until req0 is finished.
step = 0
req_id = 0 req_id = 0
expected_num_tokens = [ expected_num_tokens = [
engine_core.scheduler.requests["0"].num_tokens + 1, engine_core.scheduler.requests["0"].num_tokens + 1,
...@@ -368,8 +358,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -368,8 +358,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
] ]
while engine_core.scheduler.get_num_unfinished_requests() == 2: while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue()[0] output = engine_core.step_with_batch_queue()[0]
if step % 2 == 0: # Every step consumes an output.
# Even steps consumes an output.
assert output is not None assert output is not None
assert len(output[0].outputs) == 1 assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests: if req_id in engine_core.scheduler.requests:
...@@ -377,10 +366,6 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -377,10 +366,6 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
req_id].num_tokens == expected_num_tokens[req_id] req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1 expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2 req_id = (req_id + 1) % 2
else:
# Odd steps schedules a new batch.
assert output is None
step += 1
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
......
...@@ -75,9 +75,10 @@ async def generate( ...@@ -75,9 +75,10 @@ async def generate(
], ],
) )
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"]) @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind, async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str,
data_parallel_backend: str): async_scheduling: bool):
stats_loggers = {} stats_loggers = {}
...@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind, ...@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind,
prompt = "This is a test of data parallel" prompt = "This is a test of data parallel"
engine_args.data_parallel_backend = data_parallel_backend engine_args.data_parallel_backend = data_parallel_backend
engine_args.async_scheduling = async_scheduling
engine = AsyncLLM.from_engine_args(engine_args, engine = AsyncLLM.from_engine_args(engine_args,
stat_loggers=[SimpleStatsLogger]) stat_loggers=[SimpleStatsLogger])
after.callback(engine.shutdown) after.callback(engine.shutdown)
......
...@@ -10,6 +10,7 @@ import msgspec ...@@ -10,6 +10,7 @@ import msgspec
import vllm.platforms import vllm.platforms
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -136,6 +137,11 @@ try: ...@@ -136,6 +137,11 @@ try:
scheduler_output, intermediate_tensors) scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors): if isinstance(output, IntermediateTensors):
output = scheduler_output, output output = scheduler_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
return output return output
def override_env_vars(self, vars: Dict[str, str]): def override_env_vars(self, vars: Dict[str, str]):
......
...@@ -138,12 +138,12 @@ class EngineCore: ...@@ -138,12 +138,12 @@ class EngineCore:
# schedule and execute batches, and is required by pipeline parallelism # schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles. # to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput], self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput],
SchedulerOutput]]] = None SchedulerOutput]]] = None
if self.batch_queue_size > 1: if self.batch_queue_size > 1:
logger.info("Batch queue is enabled with size %d", logger.info("Batch queue is enabled with size %d",
self.batch_queue_size) self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size) self.batch_queue = deque(maxlen=self.batch_queue_size)
self.request_block_hasher: Optional[Callable[[Request], self.request_block_hasher: Optional[Callable[[Request],
list[BlockHash]]] = None list[BlockHash]]] = None
...@@ -319,41 +319,43 @@ class EngineCore: ...@@ -319,41 +319,43 @@ class EngineCore:
batch in the job queue is finished. batch in the job queue is finished.
3. Update the scheduler from the output. 3. Update the scheduler from the output.
""" """
assert self.batch_queue is not None batch_queue = self.batch_queue
assert batch_queue is not None
engine_core_outputs = None
scheduler_output = None
# Try to schedule a new batch if the batch queue is not full, but # Try to schedule a new batch if the batch queue is not full, but
# the scheduler may return an empty batch if all requests are scheduled. # the scheduler may return an empty batch if all requests are scheduled.
# Note that this is not blocking. # Note that this is not blocking.
if not self.batch_queue.full(): assert len(batch_queue) < self.batch_queue_size
model_executed = False
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output) future = self.model_executor.execute_model(scheduler_output)
self.batch_queue.put_nowait( batch_queue.appendleft(
(future, scheduler_output)) # type: ignore (future, scheduler_output)) # type: ignore[arg-type]
scheduled_batch = (scheduler_output is not None model_executed = scheduler_output.total_num_scheduled_tokens > 0
and scheduler_output.total_num_scheduled_tokens > 0) if model_executed and len(batch_queue) < self.batch_queue_size \
and not batch_queue[-1][0].done():
# If no more requests can be scheduled and the job queue is not empty, # Don't block on next worker response unless the queue is full
# block until the first batch in the job queue is finished. # or there are no more requests to schedule.
# TODO(comaniac): Ideally we should peek the first batch in the return None, True
# job queue to check if it's finished before scheduling a new batch,
# but peeking the first element in a queue is not thread-safe, elif not batch_queue:
# so we need more work. # Queue is empty. We should not reach here since this method should
if not scheduled_batch and not self.batch_queue.empty(): # only be called when the scheduler contains requests or the queue
future, scheduler_output = self.batch_queue.get_nowait() # is non-empty.
return None, False
# Blocking until the first result is available.
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
model_output = self.execute_model_with_error_logging( model_output = self.execute_model_with_error_logging(
lambda _: future.result(), scheduler_output) lambda _: future.result(), scheduler_output)
self.batch_queue.task_done() engine_core_outputs = self.scheduler.update_from_output(
engine_core_outputs = (self.scheduler.update_from_output( scheduler_output, model_output)
scheduler_output, model_output))
return engine_core_outputs, scheduled_batch return engine_core_outputs, model_executed
def shutdown(self): def shutdown(self):
self.structured_output_manager.clear_backend() self.structured_output_manager.clear_backend()
...@@ -388,7 +390,7 @@ class EngineCore: ...@@ -388,7 +390,7 @@ class EngineCore:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
def execute_dummy_batch(self): def execute_dummy_batch(self):
self.model_executor.collective_rpc("execute_dummy_batch") self.model_executor.execute_dummy_batch()
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request) return self.model_executor.add_lora(lora_request)
...@@ -733,7 +735,8 @@ class EngineCoreProc(EngineCore): ...@@ -733,7 +735,8 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed.""" """Exits when an engine step needs to be performed."""
waited = False waited = False
while not self.engines_running and not self.scheduler.has_requests(): while not self.engines_running and not self.scheduler.has_requests() \
and not self.batch_queue:
if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.") logger.debug("EngineCore waiting for work.")
waited = True waited = True
......
...@@ -81,12 +81,10 @@ class Executor(ExecutorBase): ...@@ -81,12 +81,10 @@ class Executor(ExecutorBase):
pass pass
def determine_available_memory(self) -> list[int]: # in bytes def determine_available_memory(self) -> list[int]: # in bytes
output = self.collective_rpc("determine_available_memory") return self.collective_rpc("determine_available_memory")
return output
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
output = self.collective_rpc("get_kv_cache_spec") return self.collective_rpc("get_kv_cache_spec")
return output
def execute_model( def execute_model(
self, self,
...@@ -96,6 +94,9 @@ class Executor(ExecutorBase): ...@@ -96,6 +94,9 @@ class Executor(ExecutorBase):
args=(scheduler_output, )) args=(scheduler_output, ))
return output[0] return output[0]
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch")
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
output = self.collective_rpc("take_draft_token_ids") output = self.collective_rpc("take_draft_token_ids")
return output[0] return output[0]
......
...@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor): ...@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor):
outputs, self.output_rank) outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch",
unique_reply_rank=self.output_rank)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
# OPTIMIZATION: Get output only from a single worker (output_rank) # OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc("take_draft_token_ids", outputs = self.collective_rpc("take_draft_token_ids",
...@@ -242,12 +246,17 @@ class MultiprocExecutor(Executor): ...@@ -242,12 +246,17 @@ class MultiprocExecutor(Executor):
dequeue_timeout = None if deadline is None else ( dequeue_timeout = None if deadline is None else (
deadline - time.monotonic()) deadline - time.monotonic())
if non_block: if self.io_thread_pool is not None:
# We must consume worker_response_mq from a single thread.
result = self.io_thread_pool.submit( # type: ignore result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event) get_response, w, dequeue_timeout, self.shutdown_event)
else: if not non_block:
result = result.result()
elif not non_block:
result = get_response(w, dequeue_timeout) result = get_response(w, dequeue_timeout)
else:
raise RuntimeError("non_block can only be used when"
" max_concurrent_batches > 1")
responses.append(result) responses.append(result)
return responses return responses
......
...@@ -354,18 +354,22 @@ class Worker(WorkerBase): ...@@ -354,18 +354,22 @@ class Worker(WorkerBase):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None intermediate_tensors = None
if not get_pp_group().is_first_rank: forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors( intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict( get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group())) all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output, output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors) intermediate_tensors)
if isinstance(output, ModelRunnerOutput):
return output
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors) assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
assert parallel_config.distributed_executor_backend != (
"external_launcher") and not get_pp_group().is_last_rank
get_pp_group().send_tensor_dict(output.tensors, get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group()) all_gather_group=get_tp_group())
...@@ -383,9 +387,6 @@ class Worker(WorkerBase): ...@@ -383,9 +387,6 @@ class Worker(WorkerBase):
output.kv_connector_output = kv_connector_output output.kv_connector_output = kv_connector_output
return output return output
assert isinstance(output, ModelRunnerOutput)
return output
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
return self.model_runner.take_draft_token_ids() return self.model_runner.take_draft_token_ids()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment