Unverified commit 8b346309, authored by Benjamin Chislett and committed by GitHub
Browse files

[Refactor] Consolidate SupportsEagle (#36063)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent 54a6db82
...@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.models.interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsPP,
)
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
...@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module): ...@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module):
return loaded_params return loaded_params
class StepDecoderModel(nn.Module): class StepDecoderModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module): ...@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = getattr(
config, "aux_hidden_state_layers", ()
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], ["hidden_states", "residual"],
config.hidden_size, config.hidden_size,
...@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module): ...@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
if idx in self.aux_hidden_state_layers:
if residual is None:
aux_hidden_states.append(hidden_states)
else:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
...@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module): ...@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module):
return hidden_states return hidden_states
class Step1ForCausalLM(nn.Module, SupportsPP): class Step1ForCausalLM(nn.Module, SupportsPP, SupportsEagle, SupportsEagle3):
packed_modules_mapping = STEP_PACKED_MODULES_MAPPING packed_modules_mapping = STEP_PACKED_MODULES_MAPPING
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -618,6 +618,6 @@ class Base( ...@@ -618,6 +618,6 @@ class Base(
# Ensure that the capture hooks are installed before dynamo traces the model # Ensure that the capture hooks are installed before dynamo traces the model
maybe_install_capturing_hooks(self.model) maybe_install_capturing_hooks(self.model)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = self.text_config.num_hidden_layers num_layers = self.text_config.num_hidden_layers
return (2, num_layers // 2, num_layers - 3) return (2, num_layers // 2, num_layers - 3)
...@@ -27,7 +27,7 @@ def set_eagle3_aux_hidden_state_layers( ...@@ -27,7 +27,7 @@ def set_eagle3_aux_hidden_state_layers(
if aux_layers: if aux_layers:
logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers) logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
else: else:
aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers() aux_layers = eagle3_model.get_eagle3_default_aux_hidden_state_layers()
logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers) logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
eagle3_model.set_aux_hidden_state_layers(aux_layers) eagle3_model.set_aux_hidden_state_layers(aux_layers)
......
...@@ -4556,7 +4556,9 @@ class GPUModelRunner( ...@@ -4556,7 +4556,9 @@ class GPUModelRunner(
aux_layers, aux_layers,
) )
else: else:
aux_layers = self.model.get_eagle3_aux_hidden_state_layers() aux_layers = (
self.model.get_eagle3_default_aux_hidden_state_layers()
)
self.model.set_aux_hidden_state_layers(aux_layers) self.model.set_aux_hidden_state_layers(aux_layers)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment