Unverified Commit a4047d4e authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Support Eagle3 (no CUDA graph) (#35029)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 965fe459
...@@ -66,6 +66,9 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput ...@@ -66,6 +66,9 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.spec_decode import init_speculator from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
set_eagle3_aux_hidden_state_layers,
)
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.states import RequestState
...@@ -133,14 +136,42 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -133,14 +136,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.output_copy_stream = torch.cuda.Stream(self.device) self.output_copy_stream = torch.cuda.Stream(self.device)
self.output_copy_event = torch.cuda.Event() self.output_copy_event = torch.cuda.Event()
# Pipeline parallelism.
self.pp_size = self.parallel_config.pipeline_parallel_size
self.use_pp = self.pp_size > 1
if self.use_pp:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
self.speculator = None
self.use_aux_hidden_state_outputs = False
if self.speculative_config is not None: if self.speculative_config is not None:
self.do_spec_decode = True self.do_spec_decode = True
self.num_speculative_steps = self.speculative_config.num_speculative_tokens self.num_speculative_steps = self.speculative_config.num_speculative_tokens
if self.is_last_pp_rank:
self.speculator = init_speculator(self.vllm_config, self.device) self.speculator = init_speculator(self.vllm_config, self.device)
if self.speculative_config.method == "eagle3":
# EAGLE3 may require auxiliary hidden states from target model outputs.
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
else: else:
self.do_spec_decode = False self.do_spec_decode = False
self.num_speculative_steps = 0 self.num_speculative_steps = 0
self.speculator = None
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
self.req_states = RequestState( self.req_states = RequestState(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
...@@ -176,28 +207,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -176,28 +207,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
# LoRA-related workers. # LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs) self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
# KV Connector if configured. # KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pipeline parallelism.
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
if self.use_pp:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len self.req_states.max_model_len = max_model_len
...@@ -220,7 +232,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -220,7 +232,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model = self.load_lora_model( self.model = self.load_lora_model(
self.model, self.vllm_config, self.device self.model, self.vllm_config, self.device
) )
if self.do_spec_decode:
if self.use_aux_hidden_state_outputs:
assert self.speculative_config is not None
set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
if self.speculator is not None:
self.speculator.load_model(self.model) self.speculator.load_model(self.model)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
...@@ -271,7 +287,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -271,7 +287,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config, self.vllm_config, self.device self.kv_cache_config, self.vllm_config, self.device
) )
check_attention_cp_compatibility(self.vllm_config) check_attention_cp_compatibility(self.vllm_config)
if self.do_spec_decode: if self.speculator is not None:
# HACK(woosuk) # HACK(woosuk)
self.speculator.set_attn( self.speculator.set_attn(
self.kv_cache_config, self.kv_cache_config,
...@@ -359,7 +375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -359,7 +375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None return None, None
assert self.execute_model_state is not None assert self.execute_model_state is not None
hidden_states, input_batch, _ = self.execute_model_state hidden_states, _, input_batch, _ = self.execute_model_state
assert hidden_states is not None # Last PP rank always has hidden_states assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states return hidden_states, sample_hidden_states
...@@ -399,7 +415,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -399,7 +415,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert sample_hidden_states is not None assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states) self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode: if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp( num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens self.parallel_config.data_parallel_size, self.max_num_tokens
) )
...@@ -465,7 +481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -465,7 +481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
has_lora=self.lora_config is not None, has_lora=self.lora_config is not None,
) )
if self.do_spec_decode: if self.speculator is not None:
self.speculator.capture_model() self.speculator.capture_model()
end_time = time.perf_counter() end_time = time.perf_counter()
...@@ -964,9 +980,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -964,9 +980,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): Here, we don't need to pass the input tensors, # NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers. # because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.cudagraph_manager.run_fullgraph( model_output = self.cudagraph_manager.run_fullgraph(
input_batch.num_tokens_after_padding input_batch.num_tokens_after_padding
) )
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
else: else:
# For piecewise and eager mode, just call model(). # For piecewise and eager mode, just call model().
positions = input_batch.positions positions = input_batch.positions
...@@ -998,12 +1019,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -998,12 +1019,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping=input_batch.slot_mappings, slot_mapping=input_batch.slot_mappings,
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.model( model_output = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
) )
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output) kv_connector_output = self.kv_connector.post_forward(scheduler_output)
...@@ -1011,12 +1037,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1011,12 +1037,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Non-last PP rank: return IntermediateTensors for sending. # Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors) assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, input_batch, kv_connector_output) self.execute_model_state = (None, None, input_batch, kv_connector_output)
return hidden_states return hidden_states
assert isinstance(hidden_states, torch.Tensor)
# Last rank (or no PP): hidden_states is a tensor for sampling. # Last rank (or no PP): hidden_states is a tensor for sampling.
self.execute_model_state = (hidden_states, input_batch, kv_connector_output) assert isinstance(hidden_states, torch.Tensor)
self.execute_model_state = (
hidden_states,
aux_hidden_states,
input_batch,
kv_connector_output,
)
return None return None
@torch.inference_mode() @torch.inference_mode()
...@@ -1024,7 +1055,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1024,7 +1055,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, grammar_output: GrammarOutput | None self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput | None: ) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None assert self.execute_model_state is not None
hidden_states, input_batch, kv_connector_output = self.execute_model_state hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
self.execute_model_state
)
self.execute_model_state = None # type: ignore self.execute_model_state = None # type: ignore
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
...@@ -1084,11 +1117,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1084,11 +1117,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.postprocess( self.postprocess(
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
) )
if self.do_spec_decode: if self.speculator is not None:
draft_tokens = self.propose_draft( draft_tokens = self.propose_draft(
input_batch, input_batch,
hidden_states, hidden_states,
None, # aux_hidden_states aux_hidden_states,
num_sampled, num_sampled,
num_rejected, num_rejected,
) )
......
...@@ -9,7 +9,7 @@ def init_speculator(vllm_config: VllmConfig, device: torch.device): ...@@ -9,7 +9,7 @@ def init_speculator(vllm_config: VllmConfig, device: torch.device):
speculative_config = vllm_config.speculative_config speculative_config = vllm_config.speculative_config
assert speculative_config is not None assert speculative_config is not None
if speculative_config.use_eagle(): if speculative_config.use_eagle():
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
return EagleSpeculator(vllm_config, device) return EagleSpeculator(vllm_config, device)
raise NotImplementedError(f"{speculative_config.method} is not supported yet.") raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch.nn as nn
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsEagle3, supports_eagle3
logger = init_logger(__name__)
def set_eagle3_aux_hidden_state_layers(
model: nn.Module,
spec_config: SpeculativeConfig,
) -> None:
if not supports_eagle3(model):
raise RuntimeError("Model does not support EAGLE3 interface")
# mypy may infer the class-level overload for supports_eagle3.
# Narrow explicitly to the runtime protocol instance.
if isinstance(model, type):
raise RuntimeError("Expected model instance for EAGLE3 configuration")
eagle3_model = cast(SupportsEagle3, model)
aux_layers = get_eagle3_aux_layers_from_config(spec_config)
if aux_layers:
logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
else:
aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers()
logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
eagle3_model.set_aux_hidden_state_layers(aux_layers)
def get_eagle3_aux_layers_from_config(
spec_config: SpeculativeConfig,
) -> tuple[int, ...] | None:
if not (spec_config and spec_config.draft_model_config):
return None
hf_config = spec_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
return None
...@@ -9,7 +9,6 @@ from vllm.config import VllmConfig ...@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
...@@ -20,7 +19,8 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -20,7 +19,8 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -73,18 +73,7 @@ class EagleSpeculator: ...@@ -73,18 +73,7 @@ class EagleSpeculator:
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device) self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag self.model = load_eagle_model(target_model, self.vllm_config)
with set_model_tag("eagle_head"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=self.draft_model_config
)
share_lm_head = True
if share_lm_head and hasattr(target_model, "lm_head"):
if hasattr(self.model, "lm_head"):
del self.model.lm_head
self.model.lm_head = target_model.lm_head
def set_attn( def set_attn(
self, self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model
def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Module:
from vllm.compilation.backends import set_model_tag
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
draft_model_config = speculative_config.draft_model_config
with set_model_tag("eagle_head"):
eagle_model = get_model(
vllm_config=vllm_config, model_config=draft_model_config
)
# Share target embeddings when the draft checkpoint does not include
# its own vocab embedding table.
share_embeddings = True
if hasattr(eagle_model, "has_own_embed_tokens"):
share_embeddings = not eagle_model.has_own_embed_tokens
if share_embeddings:
target_language_model = (
target_model.get_language_model()
if hasattr(target_model, "get_language_model")
else target_model
)
inner_model = getattr(target_language_model, "model", None)
target_embed_tokens = None
if inner_model is not None:
if hasattr(inner_model, "embed_tokens"):
target_embed_tokens = inner_model.embed_tokens
elif hasattr(inner_model, "embedding"):
target_embed_tokens = inner_model.embedding
if target_embed_tokens is not None and hasattr(eagle_model, "model"):
if hasattr(eagle_model.model, "embed_tokens"):
del eagle_model.model.embed_tokens
eagle_model.model.embed_tokens = target_embed_tokens
# Only share target lm_head when the draft model does not own one.
share_lm_head = True
if hasattr(eagle_model, "has_own_lm_head"):
share_lm_head = not eagle_model.has_own_lm_head
if share_lm_head and hasattr(target_model, "lm_head"):
if hasattr(eagle_model, "lm_head"):
del eagle_model.lm_head
eagle_model.lm_head = target_model.lm_head
return eagle_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment