Unverified Commit 8d84d836 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[BugFix][Spec Decode] Fix hidden size mismatch between target and eagle head (#17740)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 950b7118
...@@ -28,23 +28,25 @@ class EagleProposer: ...@@ -28,23 +28,25 @@ class EagleProposer:
device: torch.device, device: torch.device,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.method = self.vllm_config.speculative_config.method self.speculative_config = vllm_config.speculative_config
self.num_speculative_tokens = ( self.draft_model_config = self.speculative_config.draft_model_config
vllm_config.speculative_config.num_speculative_tokens) self.method = self.speculative_config.method
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_tokens = vllm_config.scheduler_config \ self.block_size = vllm_config.cache_config.block_size
.max_num_batched_tokens self.num_speculative_tokens = (
self.speculative_config.num_speculative_tokens)
self.hidden_size = vllm_config.model_config.get_hidden_size() self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and == CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager) not self.vllm_config.model_config.enforce_eager)
self.cudagraph_batch_sizes = list( self.cudagraph_batch_sizes = list(
reversed( reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes)) self.vllm_config.compilation_config.cudagraph_capture_sizes))
...@@ -56,7 +58,6 @@ class EagleProposer: ...@@ -56,7 +58,6 @@ class EagleProposer:
self.positions = torch.zeros(self.max_num_tokens, self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64, dtype=torch.int64,
device=device) device=device)
self.hidden_states = torch.zeros( self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size), (self.max_num_tokens, self.hidden_size),
dtype=self.dtype, dtype=self.dtype,
...@@ -131,7 +132,6 @@ class EagleProposer: ...@@ -131,7 +132,6 @@ class EagleProposer:
num_input_tokens = num_tokens num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
...@@ -209,7 +209,6 @@ class EagleProposer: ...@@ -209,7 +209,6 @@ class EagleProposer:
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states self.hidden_states[:batch_size] = hidden_states
# Run the model. # Run the model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment