Unverified Commit 50cd5674 authored by danisereb's avatar danisereb Committed by GitHub
Browse files

Fix invalid logprobs with MTP enabled and sync scheduling (#38711)


Signed-off-by: default avatarDaniel Serebrenik <daserebrenik@nvidia.com>
parent 7b1a7423
...@@ -4192,6 +4192,7 @@ class GPUModelRunner( ...@@ -4192,6 +4192,7 @@ class GPUModelRunner(
spec_config = self.speculative_config spec_config = self.speculative_config
propose_drafts_after_bookkeeping = False propose_drafts_after_bookkeeping = False
if spec_config is not None: if spec_config is not None:
# Decide whether to run the drafter or zero out draft tokens.
input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= self.effective_drafter_max_model_len <= self.effective_drafter_max_model_len
...@@ -4227,10 +4228,6 @@ class GPUModelRunner( ...@@ -4227,10 +4228,6 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count( self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
elif ( elif (
spec_config.use_ngram_gpu() spec_config.use_ngram_gpu()
and not spec_config.disable_padded_drafter_batch and not spec_config.disable_padded_drafter_batch
...@@ -4253,15 +4250,20 @@ class GPUModelRunner( ...@@ -4253,15 +4250,20 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count( self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
else: else:
propose_drafts_after_bookkeeping = input_fits_in_drafter propose_drafts_after_bookkeeping = input_fits_in_drafter
if not input_fits_in_drafter:
# Zero out draft tokens so the scheduler doesn't schedule
# stale drafts from the previous step.
# For Nemotron-H: it is necessary to zero out the draft tokens,
# otherwise the stale tokens will corrupt Mamba recurrent
# state and logprobs for sequences near max_model_len.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
with record_function_or_nullcontext("gpu_model_runner: bookkeep"): with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
( (
num_nans_in_logits, num_nans_in_logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment