Unverified Commit 39474513 authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[Model Runner V2] fix draft attention metadata generation (#37364)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent 638a872d
...@@ -30,7 +30,10 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: ...@@ -30,7 +30,10 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
def init_attn_backend( def init_attn_backend(
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
active_layer_names: set[str] | None = None,
): ):
attn_backends: dict[str, type[AttentionBackend]] = {} attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = [] attn_groups: list[list[AttentionGroup]] = []
...@@ -39,6 +42,8 @@ def init_attn_backend( ...@@ -39,6 +42,8 @@ def init_attn_backend(
kv_cache_config.kv_cache_groups kv_cache_config.kv_cache_groups
): ):
layer_names = kv_cache_group_spec.layer_names layer_names = kv_cache_group_spec.layer_names
if active_layer_names is not None:
layer_names = list(active_layer_names.intersection(layer_names))
layer_type = cast(type[Any], AttentionLayerBase) layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names) attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
......
...@@ -350,7 +350,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -350,7 +350,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculator.set_attn( self.speculator.set_attn(
self.model_state, self.model_state,
self.kv_cache_config, self.kv_cache_config,
self.attn_groups,
self.block_tables, self.block_tables,
) )
......
...@@ -5,15 +5,17 @@ from typing import Any ...@@ -5,15 +5,17 @@ from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata, build_attn_metadata,
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
init_attn_backend,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
...@@ -22,7 +24,6 @@ from vllm.v1.worker.gpu.model_states.interface import ModelState ...@@ -22,7 +24,6 @@ from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -87,18 +88,35 @@ class EagleSpeculator: ...@@ -87,18 +88,35 @@ class EagleSpeculator:
) )
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
target_attn_layer_names = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
self.model = load_eagle_model(target_model, self.vllm_config) self.model = load_eagle_model(target_model, self.vllm_config)
all_attn_layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
self.draft_attn_layer_names = set(all_attn_layers) - set(
target_attn_layer_names
)
def set_attn( def set_attn(
self, self,
model_state: ModelState, model_state: ModelState,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
block_tables: BlockTables, block_tables: BlockTables,
) -> None: ) -> None:
self.model_state = model_state self.model_state = model_state
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self.attn_groups = attn_groups _, self.attn_groups = init_attn_backend(
kv_cache_config,
self.vllm_config,
self.device,
active_layer_names=self.draft_attn_layer_names,
)
self.block_tables = block_tables self.block_tables = block_tables
@torch.inference_mode() @torch.inference_mode()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment