Unverified Commit ad510309 authored by Yong Hoon Shin's avatar Yong Hoon Shin Committed by GitHub
Browse files

Override attention metadata for fast prefill in some KV sharing setups (#21590)


Signed-off-by: default avatarYong Hoon Shin <yhshin@meta.com>
parent 366f6b3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import random
from typing import Optional, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors
from ...utils import fork_new_process_for_each_test
class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs)
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (attn_metadata is not None
and self.cache_config.kv_sharing_fast_prefill):
assert isinstance(attn_metadata, dict) # true in V1
# Gemma3n-E2B has 30 layers, with last 20 layers being
# cross-decoder layers. Check attention metadata is correct
for layer_name, metadata in attn_metadata.items():
layer_idx = extract_layer_index(layer_name)
if layer_idx >= 20:
assert hasattr(metadata, 'logits_indices_padded')
assert hasattr(metadata, 'num_logits_indices')
else:
assert not hasattr(metadata, 'logits_indices_padded')
assert not hasattr(metadata, 'num_logits_indices')
# Last layer will be a KV sharing layer
layer_attn_metadata = attn_metadata[
self.model.language_model.layers[-1].self_attn.attn.layer_name]
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
assert logits_indices_padded is not None
num_logits_indices = layer_attn_metadata.num_logits_indices
assert num_logits_indices > 0
# Reset hidden states to random values and
# only set logits at logits_indices to valid values
# Because logits_indices are the only positions that are used
# for output token sampling, this still produces same outputs
logits_hs = hidden_states[logits_indices_padded]
hidden_states = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
hidden_states[gen_indices] = logits_hs[:num_logits_indices]
return hidden_states
@pytest.fixture
def test_prompts():
"""
Adapted from tests/v1/e2e/test_spec_decode.py
"""
prompt_types = ["repeat", "sentence"]
# Setting higher num prompts increases the chance of numerics mismatch
# due to matrix multiplication numerics depending on batch dimension
num_prompts = 10
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
if kind == "repeat":
prompt = f"""please repeat the word '{word}' 10 times."""
elif kind == "sentence":
prompt = f"""please give a ten-word sentence that
uses the word {word} at least once."""
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append(prompt)
return prompts
@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
monkeypatch: pytest.MonkeyPatch,
enforce_eager: bool,
test_prompts: list[str],
):
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
TestGemma3nForConditionalGeneration)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and
# managing buffers for cudagraph
cudagraph_copy_inputs=True,
level=CompilationLevel.PIECEWISE
if not enforce_eager else CompilationLevel.NO_COMPILATION)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
)
ref_responses = llm.generate(test_prompts, sampling_params)
del llm
gc.collect()
torch.cuda.empty_cache()
llm = LLM(model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
kv_sharing_fast_prefill=True)
optimized_responses = llm.generate(test_prompts, sampling_params)
misses = 0
for ref_response, optimized_response in zip(ref_responses,
optimized_responses):
if ref_response.outputs[0].text != optimized_response.outputs[
0].text:
misses += 1
assert misses == 0
...@@ -1795,6 +1795,16 @@ class CacheConfig: ...@@ -1795,6 +1795,16 @@ class CacheConfig:
num_cpu_blocks: Optional[int] = field(default=None, init=False) num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory.""" """The number of blocks to allocate for CPU memory."""
kv_sharing_fast_prefill: bool = False
"""This feature is work in progress and no prefill optimization takes place
with this flag enabled currently.
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overriden with metadata
necessary for implementating this optimization in some models (e.g. Gemma3n)
"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -1836,6 +1846,11 @@ class CacheConfig: ...@@ -1836,6 +1846,11 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got " "GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.") f"{self.gpu_memory_utilization}.")
if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")
return self return self
def _verify_cache_dtype(self) -> None: def _verify_cache_dtype(self) -> None:
......
...@@ -445,6 +445,9 @@ class EngineArgs: ...@@ -445,6 +445,9 @@ class EngineArgs:
# DEPRECATED # DEPRECATED
enable_prompt_adapter: bool = False enable_prompt_adapter: bool = False
kv_sharing_fast_prefill: bool = \
CacheConfig.kv_sharing_fast_prefill
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
...@@ -697,6 +700,8 @@ class EngineArgs: ...@@ -697,6 +700,8 @@ class EngineArgs:
**cache_kwargs["cpu_offload_gb"]) **cache_kwargs["cpu_offload_gb"])
cache_group.add_argument("--calculate-kv-scales", cache_group.add_argument("--calculate-kv-scales",
**cache_kwargs["calculate_kv_scales"]) **cache_kwargs["calculate_kv_scales"])
cache_group.add_argument("--kv-sharing-fast-prefill",
**cache_kwargs["kv_sharing_fast_prefill"])
# Multimodal related configs # Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig) multimodal_kwargs = get_kwargs(MultiModalConfig)
...@@ -1069,6 +1074,7 @@ class EngineArgs: ...@@ -1069,6 +1074,7 @@ class EngineArgs:
prefix_caching_hash_algo=self.prefix_caching_hash_algo, prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb, cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales, calculate_kv_scales=self.calculate_kv_scales,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
) )
# Get the current placement group if Ray is initialized and # Get the current placement group if Ray is initialized and
......
...@@ -793,6 +793,7 @@ class Gemma3nForConditionalGeneration(nn.Module): ...@@ -793,6 +793,7 @@ class Gemma3nForConditionalGeneration(nn.Module):
del lora_config # Unused. del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
self.cache_config = vllm_config.cache_config
self.model = Gemma3nModel(vllm_config=vllm_config, self.model = Gemma3nModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import abc import abc
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
import numpy as np import numpy as np
import torch import torch
...@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills( ...@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills(
modified_batch = True modified_batch = True
return modified_batch return modified_batch
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),
]
def subclass_attention_metadata(
name_prefix: str,
metadata_cls: Any,
fields: list[tuple[str, Any, Any]],
) -> Any:
"""
Return a new subclass of `metadata_cls` with additional fields
"""
name: str = name_prefix + metadata_cls.__name__ # type: ignore
Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
return Wrapped
def make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls: Any, ) -> Any:
"""
Return a new subclass of `metadata_cls` for fast prefill
"""
return subclass_attention_metadata(
name_prefix="KVSharingFastPrefill",
metadata_cls=metadata_cls,
fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import gc import gc
import time import time
from contextlib import contextmanager from contextlib import contextmanager
...@@ -47,6 +48,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, ...@@ -47,6 +48,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
make_local_attention_virtual_batches) make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec,
...@@ -320,6 +322,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -320,6 +322,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# means this layer will perform attention using the keys and values # means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`. # from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {} self.shared_kv_cache_layers: dict[str, str] = {}
self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()
self.kv_sharing_fast_prefill_logits_indices = None
if self.cache_config.kv_sharing_fast_prefill:
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device)
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
...@@ -735,6 +743,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -735,6 +743,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata = None spec_decode_common_attn_metadata = None
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill:
assert self.kv_sharing_fast_prefill_logits_indices is not None
num_logits = logits_indices.shape[0]
assert num_logits > 0
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
logits_indices)
# There might have leftover indices in logits_indices[num_logits:]
# from previous iterations, whose values may be greater than the
# batch size in the current iteration. To ensure indices are always
# valid, we fill the padded indices with the last index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item())
if (self.use_cuda_graph
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(
num_logits)
else:
num_logits_padded = num_logits
logits_indices_padded = (
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
)
attn_metadata: dict[str, Any] = {} attn_metadata: dict[str, Any] = {}
# Prepare encoder attention metadata separately # Prepare encoder attention metadata separately
...@@ -806,7 +863,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -806,7 +863,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
)) ))
fast_prefill_metadata = attn_metadata_i
if (self.cache_config.kv_sharing_fast_prefill
and self.kv_sharing_fast_prefill_eligible_layers):
# Dynamically create a a dataclass type that inherits
# from attention metadata type but includes additional
# fields logits_indices_padded and num_logits_indices
# which are required for prefill truncation
fast_prefill_metadata_type = (
make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls=type(attn_metadata_i), ))
fast_prefill_metadata = fast_prefill_metadata_type(
**dataclasses.asdict(attn_metadata_i),
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
)
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
if (self.cache_config.kv_sharing_fast_prefill and layer_name
in self.kv_sharing_fast_prefill_eligible_layers):
attn_metadata[layer_name] = fast_prefill_metadata
continue
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
# Hack for now to fix chunked local attention + no hybrid kv cache # Hack for now to fix chunked local attention + no hybrid kv cache
...@@ -838,30 +916,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -838,30 +916,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
b.can_run_in_cudagraph(common_attn_metadata) b.can_run_in_cudagraph(common_attn_metadata)
for b in self.attn_metadata_builders) for b in self.attn_metadata_builders)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model # Hot-Swap lora model
if self.lora_config: if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens) self.set_active_loras(self.input_batch, num_scheduled_tokens)
...@@ -1433,6 +1487,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1433,6 +1487,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_metadata, num_scheduled_tokens_np, spec_decode_metadata, num_scheduled_tokens_np,
spec_decode_common_attn_metadata) = ( spec_decode_common_attn_metadata) = (
self._prepare_inputs(scheduler_output)) self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
...@@ -2814,6 +2869,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2814,6 +2869,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config.kv_cache_groups, kv_cache_config.kv_cache_groups,
kv_caches, kv_caches,
) )
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
# Iterate in reversed order and add layers that re-use KV cache
# e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
for layer_name in reversed(attn_layers):
if layer_name in self.shared_kv_cache_layers:
self.kv_sharing_fast_prefill_eligible_layers.add(
layer_name)
else:
break
bind_kv_cache(kv_caches, bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context, self.compilation_config.static_forward_context,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment