Unverified commit 3e440786, authored by Wentao Ye and committed by GitHub
Browse files

[Feature] Fully support for async scheduling + PP, 30.8% E2E throughput...


[Feature] Fully support for async scheduling + PP, 30.8% E2E throughput improvement, 31.8% TPOT improvement (#32618)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 8bdd3979
...@@ -42,6 +42,22 @@ def test_compile_config_repr_succeeds(): ...@@ -42,6 +42,22 @@ def test_compile_config_repr_succeeds():
assert "inductor_passes" in val assert "inductor_passes" in val
def test_async_scheduling_with_pipeline_parallelism_is_allowed():
    """Async scheduling stays enabled when the config also requests PP.

    Builds a ``VllmConfig`` with ``async_scheduling=True`` together with
    ``pipeline_parallel_size=2`` on the ``"mp"`` executor backend and checks
    that the resulting scheduler config still reports async scheduling as on.
    """
    scheduler_config = SchedulerConfig(
        max_model_len=8192,
        is_encoder_decoder=False,
        async_scheduling=True,
    )
    parallel_config = ParallelConfig(
        pipeline_parallel_size=2,
        distributed_executor_backend="mp",
        nnodes=2,
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
    )

    assert vllm_config.scheduler_config.async_scheduling is True
@dataclass @dataclass
class _TestConfigFields: class _TestConfigFields:
a: int a: int
......
...@@ -127,6 +127,21 @@ def test_schedule_multimodal_requests(): ...@@ -127,6 +127,21 @@ def test_schedule_multimodal_requests():
assert len(encoder_input) == 1 assert len(encoder_input) == 1
def test_async_scheduling_pp_allows_rescheduling_with_output_placeholders():
    """Async scheduling + PP: a request may be scheduled again while in flight.

    Schedules a single request once, checks that it now carries output
    placeholders, then schedules a second time without calling
    ``update_from_output`` in between and expects the request to be part of
    that second scheduling step as well.
    """
    scheduler = create_scheduler(async_scheduling=True, pipeline_parallel_size=2)
    requests = create_requests(num_requests=1, num_tokens=8)
    (request,) = requests
    scheduler.add_request(request)

    # First step: its output is not needed, only the effect on the request.
    scheduler.schedule()
    assert request.num_output_placeholders > 0

    # No update_from_output has happened yet, but the request should still be
    # schedulable for another step (multi-step in-flight).
    second_step = scheduler.schedule()
    assert request.request_id in second_step.num_scheduled_tokens
def test_schedule_partial_requests(): def test_schedule_partial_requests():
"""Test scheduling behavior with partial requests. """Test scheduling behavior with partial requests.
......
...@@ -344,13 +344,6 @@ class Scheduler(SchedulerInterface): ...@@ -344,13 +344,6 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
if self.use_pp and request.num_output_placeholders > 0:
req_index += 1
continue
if ( if (
request.num_output_placeholders > 0 request.num_output_placeholders > 0
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1). # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
...@@ -1003,7 +996,10 @@ class Scheduler(SchedulerInterface): ...@@ -1003,7 +996,10 @@ class Scheduler(SchedulerInterface):
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
if self.use_pp: # NOTE: In PP+async scheduling, we consume token ids via a direct GPU
# broadcast path (`input_batch.prev_sampled_token_ids`), so we can
# omit this payload.
if self.use_pp and not self.scheduler_config.async_scheduling:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
......
...@@ -1010,20 +1010,26 @@ class GPUModelRunner( ...@@ -1010,20 +1010,26 @@ class GPUModelRunner(
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank: if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back, if not req_data.new_token_ids:
# because there's no direct communication between the first- # Async scheduled PP: Sampled tokens propagated via GPU broadcast.
# stage worker and the last-stage worker. new_token_ids: list[int] = []
new_token_ids = req_data.new_token_ids[i] else:
# Add the sampled token(s) from the previous step (if any). # Non-async scheduling with PP: The scheduler sends
# This doesn't include "unverified" tokens like spec tokens. # sampled token ids back because there's no direct communication
num_new_tokens = ( # between the first-stage worker and the last-stage worker.
num_computed_tokens + len(new_token_ids) - req_state.num_tokens new_token_ids = req_data.new_token_ids[i]
) # Add the sampled token(s) from the previous step (if any).
if num_new_tokens == 1: # This doesn't include "unverified" tokens like spec tokens.
# Avoid slicing list in most common case. num_new_tokens = (
req_state.output_token_ids.append(new_token_ids[-1]) num_computed_tokens + len(new_token_ids) - req_state.num_tokens
elif num_new_tokens > 0: )
req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:]
)
elif num_output_tokens < len(req_state.output_token_ids): elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load # Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state. # failure. Align the cached state.
...@@ -3577,7 +3583,9 @@ class GPUModelRunner( ...@@ -3577,7 +3583,9 @@ class GPUModelRunner(
self.kv_connector_output = None self.kv_connector_output = None
if self.execute_model_state is None: if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used. # receive sampled token ids from the last PP rank.
if self.use_async_scheduling and get_pp_group().world_size > 1:
self._pp_receive_prev_sampled_token_ids_to_input_batch()
if not kv_connector_output: if not kv_connector_output:
return None # type: ignore[return-value] return None # type: ignore[return-value]
...@@ -3618,6 +3626,12 @@ class GPUModelRunner( ...@@ -3618,6 +3626,12 @@ class GPUModelRunner(
self._update_states_after_model_execute( self._update_states_after_model_execute(
sampler_output.sampled_token_ids, scheduler_output sampler_output.sampled_token_ids, scheduler_output
) )
if self.use_async_scheduling:
pp = get_pp_group()
if pp.world_size > 1 and pp.is_last_rank:
self._pp_broadcast_prev_sampled_token_ids(
sampler_output.sampled_token_ids
)
self._draft_token_ids = None self._draft_token_ids = None
self._draft_token_req_ids = None self._draft_token_req_ids = None
...@@ -3753,6 +3767,45 @@ class GPUModelRunner( ...@@ -3753,6 +3767,45 @@ class GPUModelRunner(
return async_output return async_output
def _pp_broadcast_prev_sampled_token_ids(
    self, sampled_token_ids: torch.Tensor
) -> None:
    """Broadcast sampled token ids (GPU) from the last PP stage.

    Must be called on the last pipeline-parallel rank; this rank is used as
    the broadcast source over the PP device group.

    Args:
        sampled_token_ids: Tensor to broadcast. Asserted to be 2-D with a
            trailing dimension of 1, i.e. shape ``[num_reqs, 1]``, which is
            the shape expected for ``prev_sampled_token_ids``.
    """
    pp_group = get_pp_group()
    assert pp_group.is_last_rank

    # Reject anything that is not a [num_reqs, 1] column of token ids.
    assert not (sampled_token_ids.dim() != 2 or sampled_token_ids.shape[-1] != 1), (
        "PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
    )

    torch.distributed.broadcast(
        sampled_token_ids,
        src=pp_group.rank,
        group=pp_group.device_group,
    )
def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None:
    """Receive sampled token ids broadcast from the last PP stage.

    Must be called on a rank that is not the last pipeline-parallel rank.
    The received tensor is stored as ``input_batch.prev_sampled_token_ids``
    and ``input_batch.prev_req_id_to_index`` is rebuilt from the current
    batch, leaving out rows flagged in ``discard_request_mask``. Every
    request that is kept and has cached state in ``self.requests`` also gets
    a ``-1`` placeholder appended to its ``output_token_ids``.
    """
    pp_group = get_pp_group()
    assert not pp_group.is_last_rank

    batch = self.input_batch
    num_reqs = batch.num_reqs

    # Receive buffer shaped like `prev_sampled_token_ids`: [num_reqs, 1].
    received = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device)
    torch.distributed.broadcast(
        received, src=pp_group.last_rank, group=pp_group.device_group
    )
    batch.prev_sampled_token_ids = received

    # Rows whose discard flag is set do not get a req_id -> row mapping.
    discard_mask = self.discard_request_mask.np[:num_reqs]
    discarded_rows = set(np.nonzero(discard_mask)[0])

    # Build `prev_req_id_to_index` here so `_prepare_input_ids` can map
    # req_id -> previous batch row.
    row_by_req_id: dict[str, int] = {}
    for row, req_id in enumerate(batch.req_ids):
        if row in discarded_rows:
            continue
        row_by_req_id[req_id] = row

        # PP+async scheduling: advance the per-request local cached output
        # length by appending a placeholder (-1) token id.
        cached_state = self.requests.get(req_id)
        if cached_state is not None:
            cached_state.output_token_ids.append(-1)

    batch.prev_req_id_to_index = row_by_req_id
def take_draft_token_ids(self) -> DraftTokenIds | None: def take_draft_token_ids(self) -> DraftTokenIds | None:
if not self.num_spec_tokens or not self._draft_token_req_ids: if not self.num_spec_tokens or not self._draft_token_req_ids:
return None return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment