"docs/vscode:/vscode.git/clone" did not exist on "98ea35601cdb34fdd618f965e7bcc3cb02a677fc"
Unverified Commit 50cd5674 authored by danisereb's avatar danisereb Committed by GitHub
Browse files

Fix invalid logprobs with MTP enabled and sync scheduling (#38711)


Signed-off-by: default avatarDaniel Serebrenik <daserebrenik@nvidia.com>
parent 7b1a7423
......@@ -4192,6 +4192,7 @@ class GPUModelRunner(
spec_config = self.speculative_config
propose_drafts_after_bookkeeping = False
if spec_config is not None:
# Decide whether to run the drafter or zero out draft tokens.
input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= self.effective_drafter_max_model_len
......@@ -4227,10 +4228,6 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
elif (
spec_config.use_ngram_gpu()
and not spec_config.disable_padded_drafter_batch
......@@ -4253,14 +4250,19 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
else:
propose_drafts_after_bookkeeping = input_fits_in_drafter
if not input_fits_in_drafter:
# Zero out draft tokens so the scheduler doesn't schedule
# stale drafts from the previous step.
# For Nemotron-H: it is necessary to zero out the draft tokens,
# otherwise the stale tokens will corrupt Mamba recurrent
# state and logprobs for sequences near max_model_len.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
else:
propose_drafts_after_bookkeeping = input_fits_in_drafter
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment