Commit e04a1b6b (unverified), authored by AlonKejzman, committed by GitHub
Browse files

[BUGFIX] Fix crash in Eagle Speculative Decoding models when exceedin… (#24662)


Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
parent 2e5df88c
@@ -2310,7 +2310,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         use_padded_batch_for_eagle = self.speculative_config and \
             self.speculative_config.use_eagle() and \
             not self.speculative_config.disable_padded_drafter_batch
-        if use_padded_batch_for_eagle:
+        effective_drafter_max_model_len = self.max_model_len
+        if effective_drafter_max_model_len is None:
+            effective_drafter_max_model_len = self.model_config.max_model_len
+        if (self.speculative_config
+                and self.speculative_config.draft_model_config is not None
+                and self.speculative_config.draft_model_config.max_model_len
+                is not None):
+            effective_drafter_max_model_len = (
+                self.speculative_config.draft_model_config.max_model_len)
+        input_fits_in_drafter = spec_decode_common_attn_metadata and (
+            spec_decode_common_attn_metadata.seq_lens.max() +
+            self.speculative_config.num_speculative_tokens
+            <= effective_drafter_max_model_len)
+        if use_padded_batch_for_eagle and input_fits_in_drafter:
             # EAGLE speculative decoding can use the GPU sampled tokens
             # as inputs, and does not need to wait for bookkeeping to finish.
             propose_draft_token_ids(sampler_output.sampled_token_ids)
@@ -2328,7 +2341,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 logits, hidden_states,
                 num_scheduled_tokens)
-        if self.speculative_config and not use_padded_batch_for_eagle:
+        if (self.speculative_config and not use_padded_batch_for_eagle
+                and input_fits_in_drafter):
             # ngram and other speculative decoding methods use the sampled
             # tokens on the CPU, so they are run after bookkeeping.
             propose_draft_token_ids(valid_sampled_token_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment