Unverified Commit d1d61f33 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Make DP work with connector-delayed new requests (#18559)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarWill Eaton <weaton@redhat.com>
parent 32ce3cf7
...@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): ...@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.running) == 4 assert len(engine_core.scheduler.running) == 4
# Loop through until they are all done. # Loop through until they are all done.
while len(engine_core.step().outputs) > 0: while len(engine_core.step()[0].outputs) > 0:
pass pass
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
...@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): ...@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
req0.request_id = req1.request_id = "test" req0.request_id = req1.request_id = "test"
engine_core.add_request(req0) engine_core.add_request(req0)
while len(engine_core.step().outputs) > 0: while len(engine_core.step()[0].outputs) > 0:
pass pass
engine_core.add_request(req1) engine_core.add_request(req1)
while len(engine_core.step().outputs) > 0: while len(engine_core.step()[0].outputs) > 0:
pass pass
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
...@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): ...@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0 assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done. # Loop through until they are all done.
while len(engine_core.step().outputs) > 0: while len(engine_core.step()[0].outputs) > 0:
pass pass
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0 assert len(engine_core.scheduler.running) == 0
...@@ -296,7 +296,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -296,7 +296,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
engine_core.add_request(req1) engine_core.add_request(req1)
# Schedule Batch 1: (10, req0) # Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue() is None assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 1 assert engine_core.batch_queue.qsize() == 1
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 10 assert scheduler_output.num_scheduled_tokens[0] == 10
...@@ -305,7 +305,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -305,7 +305,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
req0.request_id].num_computed_tokens == 10 req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1) # Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue() is None assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 2 assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1] scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 2 assert scheduler_output.num_scheduled_tokens[0] == 2
...@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[1] == 4 assert scheduler_output.num_scheduled_tokens[1] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0. # Batch queue is full. Finish Batch 2. Get first token of req0.
output = engine_core.step_with_batch_queue() output = engine_core.step_with_batch_queue()[0]
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
...@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[0] == 1 assert scheduler_output.num_scheduled_tokens[0] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1. # Batch queue is full. Finish Batch 3. Get first token of req1.
output = engine_core.step_with_batch_queue() output = engine_core.step_with_batch_queue()[0]
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
...@@ -358,7 +358,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -358,7 +358,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
engine_core.scheduler.requests[1].num_tokens + 1, engine_core.scheduler.requests[1].num_tokens + 1,
] ]
while engine_core.scheduler.get_num_unfinished_requests() == 2: while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue() output = engine_core.step_with_batch_queue()[0]
if step % 2 == 0: if step % 2 == 0:
# Even steps consumes an output. # Even steps consumes an output.
assert output is not None assert output is not None
......
...@@ -101,7 +101,7 @@ def get_forward_context() -> ForwardContext: ...@@ -101,7 +101,7 @@ def get_forward_context() -> ForwardContext:
def set_forward_context(attn_metadata: Any, def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig, vllm_config: VllmConfig,
virtual_engine: int = 0, virtual_engine: int = 0,
num_tokens: int = 0): num_tokens: Optional[int] = None):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
Here we can inject common logic for every model forward pass. Here we can inject common logic for every model forward pass.
...@@ -111,9 +111,10 @@ def set_forward_context(attn_metadata: Any, ...@@ -111,9 +111,10 @@ def set_forward_context(attn_metadata: Any,
if need_to_track_batchsize: if need_to_track_batchsize:
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1: if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None):
dp_metadata = DPMetadata.make(vllm_config.parallel_config, dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens) attn_metadata, num_tokens or 0)
global _forward_context global _forward_context
prev_context = _forward_context prev_context = _forward_context
......
...@@ -211,8 +211,12 @@ class EngineCore: ...@@ -211,8 +211,12 @@ class EngineCore:
# Re-raise exception # Re-raise exception
raise err raise err
def step(self) -> EngineCoreOutputs: def step(self) -> tuple[EngineCoreOutputs, bool]:
"""Schedule, execute, and make output.""" """Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
# Check for any requests remaining in the scheduler - unfinished, # Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch. # or finished and not yet removed from the batch.
...@@ -220,15 +224,17 @@ class EngineCore: ...@@ -220,15 +224,17 @@ class EngineCore:
return EngineCoreOutputs( return EngineCoreOutputs(
outputs=[], outputs=[],
scheduler_stats=self.scheduler.make_stats(), scheduler_stats=self.scheduler.make_stats(),
) ), False
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
model_output = self.execute_model(scheduler_output) model_output = self.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore scheduler_output, model_output) # type: ignore
return engine_core_outputs return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: def step_with_batch_queue(
self) -> tuple[Optional[EngineCoreOutputs], bool]:
"""Schedule and execute batches with the batch queue. """Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned. Note that if nothing to output in this step, None is returned.
...@@ -273,7 +279,7 @@ class EngineCore: ...@@ -273,7 +279,7 @@ class EngineCore:
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) scheduler_output, model_output)
return engine_core_outputs return engine_core_outputs, scheduled_batch
def shutdown(self): def shutdown(self):
self.structured_output_manager.clear_backend() self.structured_output_manager.clear_backend()
...@@ -537,15 +543,17 @@ class EngineCoreProc(EngineCore): ...@@ -537,15 +543,17 @@ class EngineCoreProc(EngineCore):
req = self.input_queue.get_nowait() req = self.input_queue.get_nowait()
self._handle_client_request(*req) self._handle_client_request(*req)
def _process_engine_step(self): def _process_engine_step(self) -> bool:
"""Called only when there are unfinished local requests.""" """Called only when there are unfinished local requests."""
# Step the engine core. # Step the engine core.
outputs = self.step_fn() outputs, model_executed = self.step_fn()
# Put EngineCoreOutputs into the output queue. # Put EngineCoreOutputs into the output queue.
if outputs is not None: if outputs is not None:
self.output_queue.put_nowait(outputs) self.output_queue.put_nowait(outputs)
return model_executed
def _handle_client_request(self, request_type: EngineCoreRequestType, def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
"""Dispatch request from client.""" """Dispatch request from client."""
...@@ -749,30 +757,16 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -749,30 +757,16 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# 2) Step the engine core.
executed = self._process_engine_step()
local_unfinished_reqs = self.scheduler.has_unfinished_requests() local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed:
if local_unfinished_reqs: if not local_unfinished_reqs and not self.engines_running:
# 2) Step the engine core.
self._process_engine_step()
# Check if we have now finished all requests.
local_unfinished_reqs = (
self.scheduler.has_unfinished_requests())
else:
if self.scheduler.has_finished_requests():
# There are no unfinished requests, but there are some
# finished requests remaining to be removed from the
# batch state. This engine step won't perform a forward
# pass but will flush the finished requests to ensure
# up-to-date state is returned in the engine outputs.
self._process_engine_step()
if not self.engines_running:
# All engines are idle. # All engines are idle.
continue continue
# There must be unfinished requests in DP peers, run a # We are in a running state and so must execute a dummy pass
# dummy forward pass. # if the model didn't execute any ready requests.
self.execute_dummy_batch() self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs. # 3) All-reduce operation to determine global unfinished reqs.
......
...@@ -206,7 +206,8 @@ class InprocClient(EngineCoreClient): ...@@ -206,7 +206,8 @@ class InprocClient(EngineCoreClient):
self.engine_core = EngineCore(*args, **kwargs) self.engine_core = EngineCore(*args, **kwargs)
def get_output(self) -> EngineCoreOutputs: def get_output(self) -> EngineCoreOutputs:
return self.engine_core.step() outputs, _ = self.engine_core.step()
return outputs
def add_request(self, request: EngineCoreRequest) -> None: def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request) self.engine_core.add_request(request)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment