Unverified Commit 747256bb authored by Jing Wang's avatar Jing Wang Committed by GitHub
Browse files

[Bugfix][Core] Fix stuck chunked pipeline parallelism with async scheduling (#38726)


Signed-off-by: Jing Wang <jingwang96@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
parent 1174723e
......@@ -3758,6 +3758,15 @@ class GPUModelRunner(
return slot_mappings_by_gid, slot_mappings_by_layer
def _is_all_reqs_chunked_prefill(self) -> bool:
    """Report whether every scheduled request has its discard flag set.

    Inspects the first ``num_reqs`` entries of ``discard_request_mask``
    and returns True only when all of them are set. Requests flagged this
    way have their sampled tokens discarded -- for example chunked prefill
    requests that are not yet on their last prefill chunk.

    NOTE(review): with zero scheduled requests the slice is empty, and
    ``.all()`` on an empty NumPy array is True -- confirm callers expect
    that (assumes ``.np`` is a NumPy array).
    """
    scheduled_count = self.input_batch.num_reqs
    discard_flags = self.discard_request_mask.np[:scheduled_count]
    return bool(discard_flags.all())
@torch.inference_mode()
def execute_model(
self,
......@@ -4361,6 +4370,9 @@ class GPUModelRunner(
assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, (
"PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
)
# Skip for chunked prefill: sampled tokens are dummy
# and will be discarded, no need to broadcast.
if not self._is_all_reqs_chunked_prefill():
torch.distributed.broadcast(
sampled_token_ids, src=pp.rank, group=pp.device_group
)
......@@ -4372,6 +4384,8 @@ class GPUModelRunner(
num_reqs = self.input_batch.num_reqs
# `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
recv = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device)
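# Skip the broadcast when every request is a non-final chunked prefill:
# the sampled tokens are dummies that will be discarded.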
if not self._is_all_reqs_chunked_prefill():
torch.distributed.broadcast(recv, src=pp.last_rank, group=pp.device_group)
self.input_batch.prev_sampled_token_ids = recv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment