"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "efa6bed264b2dbb4c5d7a28e49fab60f6c69aef2"
Unverified Commit d90d8eb6 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Async scheduling and PP compatibility with DP (#23770)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent 0a2f4c07
...@@ -306,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -306,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Schedule Batch 1: (10, req0) # Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue()[0] is None assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 1 assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 10 assert scheduler_output.num_scheduled_tokens["0"] == 10
# num_computed_tokens should have been updated immediately. # num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[ assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10 req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1) # Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue()[0] is None assert engine_core.step_with_batch_queue()[0] == {}
assert engine_core.batch_queue.qsize() == 2 assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 2 assert scheduler_output.num_scheduled_tokens["0"] == 2
assert scheduler_output.num_scheduled_tokens["1"] == 8 assert scheduler_output.num_scheduled_tokens["1"] == 8
# num_computed_tokens should have been updated immediately. # num_computed_tokens should have been updated immediately.
...@@ -325,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -325,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert engine_core.scheduler.get_num_unfinished_requests() == 2 assert engine_core.scheduler.get_num_unfinished_requests() == 2
# Batch queue is full. Finish Batch 1. # Finish Batch 1 and schedule Batch 3: (4, req1).
engine_core.step_with_batch_queue() # Note that req0 cannot be scheduled
# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# because it is in the decoding stage now. # because it is in the decoding stage now.
engine_core.step_with_batch_queue() engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2 assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 4 assert scheduler_output.num_scheduled_tokens["1"] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0. # Finish Batch 2. Get first token of req0.
# Schedule Batch 4: (1, req0).
output = engine_core.step_with_batch_queue()[0].get(0) output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
scheduler_output = engine_core.batch_queue[-1][1]
# Schedule Batch 4: (1, req0).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 1 assert scheduler_output.num_scheduled_tokens["0"] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1. # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
output = engine_core.step_with_batch_queue()[0].get(0) output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
scheduler_output = engine_core.batch_queue[-1][1]
# Schedule Batch 5: (1, req1).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 1 assert scheduler_output.num_scheduled_tokens["1"] == 1
# Loop until req0 is finished. # Loop until req0 is finished.
step = 0
req_id = 0 req_id = 0
expected_num_tokens = [ expected_num_tokens = [
engine_core.scheduler.requests["0"].num_tokens + 1, engine_core.scheduler.requests["0"].num_tokens + 1,
...@@ -368,19 +358,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -368,19 +358,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
] ]
while engine_core.scheduler.get_num_unfinished_requests() == 2: while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue()[0] output = engine_core.step_with_batch_queue()[0]
if step % 2 == 0: # Every step consumes an output.
# Even steps consumes an output. assert output is not None
assert output is not None assert len(output[0].outputs) == 1
assert len(output[0].outputs) == 1 if req_id in engine_core.scheduler.requests:
if req_id in engine_core.scheduler.requests: assert engine_core.scheduler.requests[
assert engine_core.scheduler.requests[ req_id].num_tokens == expected_num_tokens[req_id]
req_id].num_tokens == expected_num_tokens[req_id] expected_num_tokens[req_id] += 1
expected_num_tokens[req_id] += 1 req_id = (req_id + 1) % 2
req_id = (req_id + 1) % 2
else:
# Odd steps schedules a new batch.
assert output is None
step += 1
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
......
...@@ -75,9 +75,10 @@ async def generate( ...@@ -75,9 +75,10 @@ async def generate(
], ],
) )
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"]) @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind, async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str,
data_parallel_backend: str): async_scheduling: bool):
stats_loggers = {} stats_loggers = {}
...@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind, ...@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind,
prompt = "This is a test of data parallel" prompt = "This is a test of data parallel"
engine_args.data_parallel_backend = data_parallel_backend engine_args.data_parallel_backend = data_parallel_backend
engine_args.async_scheduling = async_scheduling
engine = AsyncLLM.from_engine_args(engine_args, engine = AsyncLLM.from_engine_args(engine_args,
stat_loggers=[SimpleStatsLogger]) stat_loggers=[SimpleStatsLogger])
after.callback(engine.shutdown) after.callback(engine.shutdown)
......
...@@ -10,6 +10,7 @@ import msgspec ...@@ -10,6 +10,7 @@ import msgspec
import vllm.platforms import vllm.platforms
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -136,6 +137,11 @@ try: ...@@ -136,6 +137,11 @@ try:
scheduler_output, intermediate_tensors) scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors): if isinstance(output, IntermediateTensors):
output = scheduler_output, output output = scheduler_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
return output return output
def override_env_vars(self, vars: Dict[str, str]): def override_env_vars(self, vars: Dict[str, str]):
......
...@@ -138,12 +138,12 @@ class EngineCore: ...@@ -138,12 +138,12 @@ class EngineCore:
# schedule and execute batches, and is required by pipeline parallelism # schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles. # to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput], self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput],
SchedulerOutput]]] = None SchedulerOutput]]] = None
if self.batch_queue_size > 1: if self.batch_queue_size > 1:
logger.info("Batch queue is enabled with size %d", logger.info("Batch queue is enabled with size %d",
self.batch_queue_size) self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size) self.batch_queue = deque(maxlen=self.batch_queue_size)
self.request_block_hasher: Optional[Callable[[Request], self.request_block_hasher: Optional[Callable[[Request],
list[BlockHash]]] = None list[BlockHash]]] = None
...@@ -319,41 +319,43 @@ class EngineCore: ...@@ -319,41 +319,43 @@ class EngineCore:
batch in the job queue is finished. batch in the job queue is finished.
3. Update the scheduler from the output. 3. Update the scheduler from the output.
""" """
assert self.batch_queue is not None batch_queue = self.batch_queue
assert batch_queue is not None
engine_core_outputs = None
scheduler_output = None
# Try to schedule a new batch if the batch queue is not full, but # Try to schedule a new batch if the batch queue is not full, but
# the scheduler may return an empty batch if all requests are scheduled. # the scheduler may return an empty batch if all requests are scheduled.
# Note that this is not blocking. # Note that this is not blocking.
if not self.batch_queue.full(): assert len(batch_queue) < self.batch_queue_size
scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output)
self.batch_queue.put_nowait(
(future, scheduler_output)) # type: ignore
scheduled_batch = (scheduler_output is not None
and scheduler_output.total_num_scheduled_tokens > 0)
# If no more requests can be scheduled and the job queue is not empty,
# block until the first batch in the job queue is finished.
# TODO(comaniac): Ideally we should peek the first batch in the
# job queue to check if it's finished before scheduling a new batch,
# but peeking the first element in a queue is not thread-safe,
# so we need more work.
if not scheduled_batch and not self.batch_queue.empty():
future, scheduler_output = self.batch_queue.get_nowait()
# Blocking until the first result is available. model_executed = False
model_output = self.execute_model_with_error_logging( if self.scheduler.has_requests():
lambda _: future.result(), scheduler_output) scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output)
batch_queue.appendleft(
(future, scheduler_output)) # type: ignore[arg-type]
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if model_executed and len(batch_queue) < self.batch_queue_size \
and not batch_queue[-1][0].done():
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
elif not batch_queue:
# Queue is empty. We should not reach here since this method should
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
model_output = self.execute_model_with_error_logging(
lambda _: future.result(), scheduler_output)
self.batch_queue.task_done() engine_core_outputs = self.scheduler.update_from_output(
engine_core_outputs = (self.scheduler.update_from_output( scheduler_output, model_output)
scheduler_output, model_output))
return engine_core_outputs, scheduled_batch return engine_core_outputs, model_executed
def shutdown(self): def shutdown(self):
self.structured_output_manager.clear_backend() self.structured_output_manager.clear_backend()
...@@ -388,7 +390,7 @@ class EngineCore: ...@@ -388,7 +390,7 @@ class EngineCore:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
def execute_dummy_batch(self): def execute_dummy_batch(self):
self.model_executor.collective_rpc("execute_dummy_batch") self.model_executor.execute_dummy_batch()
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request) return self.model_executor.add_lora(lora_request)
...@@ -733,7 +735,8 @@ class EngineCoreProc(EngineCore): ...@@ -733,7 +735,8 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed.""" """Exits when an engine step needs to be performed."""
waited = False waited = False
while not self.engines_running and not self.scheduler.has_requests(): while not self.engines_running and not self.scheduler.has_requests() \
and not self.batch_queue:
if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.") logger.debug("EngineCore waiting for work.")
waited = True waited = True
......
...@@ -81,12 +81,10 @@ class Executor(ExecutorBase): ...@@ -81,12 +81,10 @@ class Executor(ExecutorBase):
pass pass
def determine_available_memory(self) -> list[int]: # in bytes def determine_available_memory(self) -> list[int]: # in bytes
output = self.collective_rpc("determine_available_memory") return self.collective_rpc("determine_available_memory")
return output
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
output = self.collective_rpc("get_kv_cache_spec") return self.collective_rpc("get_kv_cache_spec")
return output
def execute_model( def execute_model(
self, self,
...@@ -96,6 +94,9 @@ class Executor(ExecutorBase): ...@@ -96,6 +94,9 @@ class Executor(ExecutorBase):
args=(scheduler_output, )) args=(scheduler_output, ))
return output[0] return output[0]
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch")
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
output = self.collective_rpc("take_draft_token_ids") output = self.collective_rpc("take_draft_token_ids")
return output[0] return output[0]
......
...@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor): ...@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor):
outputs, self.output_rank) outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch",
unique_reply_rank=self.output_rank)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
# OPTIMIZATION: Get output only from a single worker (output_rank) # OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc("take_draft_token_ids", outputs = self.collective_rpc("take_draft_token_ids",
...@@ -242,12 +246,17 @@ class MultiprocExecutor(Executor): ...@@ -242,12 +246,17 @@ class MultiprocExecutor(Executor):
dequeue_timeout = None if deadline is None else ( dequeue_timeout = None if deadline is None else (
deadline - time.monotonic()) deadline - time.monotonic())
if non_block: if self.io_thread_pool is not None:
# We must consume worker_response_mq from a single thread.
result = self.io_thread_pool.submit( # type: ignore result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event) get_response, w, dequeue_timeout, self.shutdown_event)
else: if not non_block:
result = result.result()
elif not non_block:
result = get_response(w, dequeue_timeout) result = get_response(w, dequeue_timeout)
else:
raise RuntimeError("non_block can only be used when"
" max_concurrent_batches > 1")
responses.append(result) responses.append(result)
return responses return responses
......
...@@ -354,36 +354,37 @@ class Worker(WorkerBase): ...@@ -354,36 +354,37 @@ class Worker(WorkerBase):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None intermediate_tensors = None
if not get_pp_group().is_first_rank: forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors( intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict( get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group())) all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output, output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors) intermediate_tensors)
if isinstance(output, ModelRunnerOutput):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \ assert parallel_config.distributed_executor_backend != (
and not get_pp_group().is_last_rank: "external_launcher") and not get_pp_group().is_last_rank
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors, get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group()) all_gather_group=get_tp_group())
kv_connector_output = output.kv_connector_output kv_connector_output = output.kv_connector_output
if not kv_connector_output: if not kv_connector_output:
return None return None
# In case of PP with kv transfer, we need to pass through the # In case of PP with kv transfer, we need to pass through the
# kv_connector_output # kv_connector_output
if (not kv_connector_output.finished_sending if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving): and not kv_connector_output.finished_recving):
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
assert isinstance(output, ModelRunnerOutput) output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output return output
def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment