"benchmarks/vscode:/vscode.git/clone" did not exist on "a3a51d20e7d040542118f04f5089c57a27bc7aca"
Unverified Commit 9a6a66f3 authored by Zijing Liu's avatar Zijing Liu Committed by GitHub
Browse files

[MRv2]fix: model accuracy regression caused by reusing the stale...


[MRv2]fix: model accuracy regression caused by reusing the stale last_sampled_tokens and draft_tokens (#39833)
Signed-off-by: default avatarZijing Liu <liuzijing2014@gmail.com>
parent 67eb6083
......@@ -102,6 +102,18 @@ class RequestState:
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
if num_computed_tokens > 0 and num_computed_tokens <= prefill_len:
# For PD disagg or resumed requests: set last_sampled to the last
# computed token so the first decode step gets the right input_id.
# For fresh prefill requests (num_computed_tokens == 0) the tensor
# is not read by combine_sampled_and_draft_tokens so we skip the
# write. Use a slice assignment rather than scalar indexing so the
# write is dispatched through fill_ without a host/device sync.
self.last_sampled_tokens[req_idx : req_idx + 1] = all_token_ids[
num_computed_tokens - 1
]
self.draft_tokens[req_idx].zero_()
def apply_staged_writes(self) -> None:
self.prompt_len.copy_to_uva()
self.prefill_len.copy_to_uva()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment