Unverified commit 21467f9a, authored by Eldar Kurtić and committed by GitHub
Browse files

Enable Eagle3 speculative decoding for GPT-OSS model (#25246)


Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
parent f92d9526
...@@ -527,7 +527,7 @@ class SpeculativeConfig: ...@@ -527,7 +527,7 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got " "speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}") f"{self.disable_by_batch_size=}")
eagle3_target_supported = ["llama", "qwen"] eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
if self.method == "eagle3" and self.target_model_config and not any( if self.method == "eagle3" and self.target_model_config and not any(
supported_model in supported_model in
self.target_model_config.hf_text_config.model_type self.target_model_config.hf_text_config.model_type
......
...@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv from vllm.utils import cdiv
from .interfaces import SupportsPP from .interfaces import SupportsEagle3, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
...@@ -238,6 +238,7 @@ class GptOssModel(nn.Module): ...@@ -238,6 +238,7 @@ class GptOssModel(nn.Module):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size)) ["hidden_states", "residual"], self.config.hidden_size))
self.aux_hidden_state_layers = tuple[int, ...]()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids) return self.embedding(input_ids)
...@@ -261,8 +262,12 @@ class GptOssModel(nn.Module): ...@@ -261,8 +262,12 @@ class GptOssModel(nn.Module):
x = intermediate_tensors["hidden_states"] x = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(x if residual is None else x +
residual)
x, residual = layer(x, positions, residual) x, residual = layer(x, positions, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -270,6 +275,9 @@ class GptOssModel(nn.Module): ...@@ -270,6 +275,9 @@ class GptOssModel(nn.Module):
"residual": residual "residual": residual
}) })
x, _ = self.norm(x, residual) x, _ = self.norm(x, residual)
if len(aux_hidden_states) > 0:
return x, aux_hidden_states
return x return x
def _load_weights_mxfp4( def _load_weights_mxfp4(
...@@ -610,7 +618,7 @@ class GptOssModel(nn.Module): ...@@ -610,7 +618,7 @@ class GptOssModel(nn.Module):
weights, stacked_params_mapping) weights, stacked_params_mapping)
class GptOssForCausalLM(nn.Module, SupportsPP): class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
...@@ -658,6 +666,13 @@ class GptOssForCausalLM(nn.Module, SupportsPP): ...@@ -658,6 +666,13 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -823,15 +823,29 @@ class EagleProposer: ...@@ -823,15 +823,29 @@ class EagleProposer:
else: else:
target_language_model = target_model target_language_model = target_model
# share embed_tokens with the target model if needed # share embed_tokens with the target model if needed
if get_pp_group().world_size == 1 \ if get_pp_group().world_size == 1:
and self.model.model.embed_tokens.weight.shape \ if hasattr(target_language_model.model, 'embed_tokens'):
== target_language_model.model.embed_tokens.weight.shape: target_embed_tokens = target_language_model.model.embed_tokens
logger.info( elif hasattr(target_language_model.model, 'embedding'):
"Assuming the EAGLE head shares the same vocab embedding" target_embed_tokens = target_language_model.model.embedding
" with the target model.") else:
del self.model.model.embed_tokens raise AttributeError(
self.model.model.embed_tokens = ( "Target model does not have 'embed_tokens' or 'embedding' "
target_language_model.model.embed_tokens) "attribute")
# Check if shapes match and we found the embedding
eagle_shape = self.model.model.embed_tokens.weight.shape
target_shape = target_embed_tokens.weight.shape
if eagle_shape == target_shape:
logger.info(
"Assuming the EAGLE head shares the same vocab embedding"
" with the target model.")
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
"The EAGLE head's vocab embedding will be loaded separately"
" from the target model.")
else: else:
logger.info( logger.info(
"The EAGLE head's vocab embedding will be loaded separately" "The EAGLE head's vocab embedding will be loaded separately"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment