Unverified commit cb293f6a, authored by Yong Hoon Shin, committed by GitHub
Browse files

[V1] Enable prefill optimization for Gemma3n (#22628)


Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
parent 7ffbf272
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Optional, Union
import pytest
import torch
......@@ -10,12 +9,6 @@ import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n_mm import (
Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors
from ...utils import fork_new_process_for_each_test
......@@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
SEED = 42
class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
    """Gemma3n wrapper used to sanity-check KV-sharing fast prefill.

    After the regular forward pass it checks which layers received the
    fast-prefill attention metadata fields, then replaces every hidden state
    that is not located at a logits index with random values.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = super().forward(input_ids, positions,
                                        intermediate_tensors, inputs_embeds,
                                        **kwargs)
        attn_metadata = get_forward_context().attn_metadata

        # attn_metadata is None during dummy runs
        if attn_metadata is None:
            return hidden_states
        if not self.language_model.cache_config.kv_sharing_fast_prefill:
            return hidden_states

        assert isinstance(attn_metadata, dict)  # true in V1

        # Gemma3n-E2B has 30 layers, with last 20 layers being
        # cross-decoder layers. Only those layers may expose the extra
        # fast-prefill fields; all earlier layers must not.
        fast_prefill_attrs = ('logits_indices_padded', 'num_logits_indices')
        for layer_name, metadata in attn_metadata.items():
            is_cross_decoder = extract_layer_index(layer_name) >= 20
            for attr in fast_prefill_attrs:
                assert hasattr(metadata, attr) == is_cross_decoder

        # Last layer will be a KV sharing layer
        last_attn = self.language_model.model.layers[-1].self_attn.attn
        last_layer_metadata = attn_metadata[last_attn.layer_name]
        padded_indices = last_layer_metadata.logits_indices_padded
        assert padded_indices is not None
        num_indices = last_layer_metadata.num_logits_indices
        assert num_indices > 0

        # Randomize all hidden states, then restore only the ones at the
        # (unpadded) logits indices. Those are the only positions used for
        # output token sampling, so the generated outputs stay the same.
        kept_states = hidden_states[padded_indices]
        randomized = torch.randn_like(hidden_states)
        randomized[padded_indices[:num_indices]] = kept_states[:num_indices]
        return randomized
@pytest.fixture
def test_prompts():
"""
......@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
enforce_eager: bool,
test_prompts: list[str],
):
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
TestGemma3nForConditionalGeneration)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and
......
......@@ -145,12 +145,19 @@ class CacheConfig:
self._verify_cache_dtype()
self._verify_prefix_caching()
self._verify_kv_sharing_fast_prefill()
def metrics_info(self):
    """Return the instance attributes as a ``{name: str(value)}`` dict.

    Every value is stringified so the mapping can be exported as
    Prometheus metrics info.
    """
    return {name: str(val) for name, val in vars(self).items()}
def _verify_kv_sharing_fast_prefill(self) -> None:
    """Reject ``kv_sharing_fast_prefill`` when the V1 engine is not in use.

    Raises:
        NotImplementedError: If ``kv_sharing_fast_prefill`` is set while
            ``envs.VLLM_USE_V1`` is falsy.
    """
    if not self.kv_sharing_fast_prefill:
        return
    if envs.VLLM_USE_V1:
        return
    raise NotImplementedError(
        "Fast prefill optimization for KV sharing is not supported "
        "in V0 currently.")
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0:
......@@ -162,11 +169,6 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")
return self
def _verify_cache_dtype(self) -> None:
......
This diff is collapsed.
......@@ -620,7 +620,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
if input_ids is not None:
per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings(
input_ids)
per_layer_inputs = per_layer_inputs.reshape(
-1, self.config.text_config.num_hidden_layers,
......
......@@ -4,11 +4,13 @@ import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
from dataclasses import dataclass, fields, make_dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
TypeVar)
import numpy as np
import torch
from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
......@@ -19,7 +21,8 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
......@@ -65,6 +68,10 @@ class CommonAttentionMetadata:
causal: bool = True
# Needed by FastPrefillAttentionBuilder
logits_indices_padded: Optional[torch.Tensor] = None
num_logits_indices: Optional[int] = None
@dataclass
class UbatchSlice:
......@@ -542,6 +549,69 @@ def make_local_attention_virtual_batches(
)
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """Build common attention metadata covering only the logits positions.

    If ``max_query_len`` is 1 the input is returned unchanged. Otherwise a
    new ``CommonAttentionMetadata`` is returned whose ``query_start_loc``,
    ``num_actual_tokens`` and ``max_query_len`` are recomputed from the
    number of logits indices that fall into each request; sequence lengths,
    block table and slot mapping are carried over from the input.

    Args:
        common_attn_metadata: Metadata for the full batch. Its
            ``logits_indices_padded`` and ``num_logits_indices`` must be set
            unless ``max_query_len`` is 1.

    Returns:
        The input itself (``max_query_len == 1``) or the reduced metadata.
    """
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    padded_indices = common_attn_metadata.logits_indices_padded
    assert padded_indices is not None
    num_indices = common_attn_metadata.num_logits_indices
    assert num_indices is not None

    # Get rid of CUDAGraph padding, if any
    logits_indices = padded_indices[:num_indices]

    num_reqs = common_attn_metadata.num_reqs
    full_query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens

    # Worked example
    #   num_reqs:         3
    #   logits_indices:   [14, 18, 19, 27]
    #   query_start_loc:  [0, 15, 20, 28]
    #   seq_lens:         [41, 31, 40]

    # Request that owns each logits index -> [0, 1, 1, 2]
    owning_request = torch.bucketize(logits_indices,
                                     full_query_start_loc[1:],
                                     right=True)

    # Number of logits tokens per request -> [1, 2, 1]
    tokens_per_request = torch.bincount(owning_request, minlength=num_reqs)

    # query_start_loc over the logits tokens only -> [0, 1, 3, 4]
    reduced_query_start_loc = torch.empty(
        num_reqs + 1,
        device=full_query_start_loc.device,
        dtype=full_query_start_loc.dtype)
    reduced_query_start_loc[0] = 0
    reduced_query_start_loc[1:] = torch.cumsum(tokens_per_request, dim=0)

    reduced_max_query_len = int(tokens_per_request.max().item())
    reduced_num_tokens = int(tokens_per_request.sum().item())

    return CommonAttentionMetadata(
        query_start_loc=reduced_query_start_loc,
        query_start_loc_cpu=reduced_query_start_loc.to("cpu",
                                                       non_blocking=True),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=reduced_num_tokens,
        max_query_len=reduced_max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
def subclass_attention_backend(
name_prefix: str, attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]]
......@@ -679,13 +749,56 @@ def subclass_attention_metadata(
return Wrapped
def make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls: Any, ) -> Any:
    """Derive the fast-prefill variant of an attention metadata class.

    Delegates to ``subclass_attention_metadata`` using the
    ``"KVSharingFastPrefill"`` name prefix and the extra fields listed in
    ``KV_SHARING_FAST_PREFILL_METADATA_FIELDS``.
    """
    fast_prefill_cls = subclass_attention_metadata(
        name_prefix="KVSharingFastPrefill",
        metadata_cls=metadata_cls,
        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
    )
    return fast_prefill_cls
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    """Structural type for attention metadata that supports fast prefill.

    Runtime-checkable, so ``isinstance(obj, KVSharingFastPrefillMetadata)``
    holds for any object exposing both attributes below.
    """

    # Logits indices, possibly including padding entries at the end.
    logits_indices_padded: torch.Tensor
    # Number of leading entries of `logits_indices_padded` that are real.
    num_logits_indices: int
def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
    """Wrap an attention backend so its metadata supports fast prefill.

    The returned backend is produced by ``subclass_attention_backend`` from
    ``underlying_attn_backend`` with a metadata builder that:

    1. rewrites the common attention metadata with
       ``make_kv_sharing_fast_prefill_common_attn_metadata``,
    2. lets the underlying builder build from the rewritten metadata, and
    3. returns the result as an object that additionally carries
       ``logits_indices_padded`` and ``num_logits_indices`` taken from the
       *original* common metadata (i.e. it satisfies
       ``KVSharingFastPrefillMetadata``).

    Args:
        prefix: Name prefix forwarded to ``subclass_attention_backend``.
        underlying_attn_backend: Backend class to wrap. It is used as a
            class: ``get_builder_cls()`` is called on it directly.

    Returns:
        The derived backend class.
    """
    underlying_builder = underlying_attn_backend.get_builder_cls()

    # Wrapper metadata classes keyed by the metadata class they extend.
    # The wrapper depends only on that class, so it is created once and
    # reused instead of defining a new class on every build() call.
    wrapped_metadata_classes: dict[type, type] = {}

    def _fast_prefill_metadata_cls(metadata_cls: type) -> type:
        """Return the wrapper class for `metadata_cls`, creating it once."""
        cached_cls = wrapped_metadata_classes.get(metadata_cls)
        if cached_cls is not None:
            return cached_cls

        class KVSharingFastPrefillAttentionMetadata(
                metadata_cls,  # type: ignore
                KVSharingFastPrefillMetadata):

            def __init__(self, metadata, common_attn_metadata):
                # Shallow copy all fields in metadata cls
                for field in fields(metadata_cls):
                    setattr(self, field.name, getattr(metadata, field.name))

                # Set additional fields that will be used in model code
                assert (common_attn_metadata.logits_indices_padded
                        is not None
                        and common_attn_metadata.num_logits_indices
                        is not None)
                self.logits_indices_padded = \
                    common_attn_metadata.logits_indices_padded
                self.num_logits_indices = \
                    common_attn_metadata.num_logits_indices

        wrapped_metadata_classes[metadata_cls] = (
            KVSharingFastPrefillAttentionMetadata)
        return KVSharingFastPrefillAttentionMetadata

    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            new_common_attn_metadata = \
                make_kv_sharing_fast_prefill_common_attn_metadata(
                    common_attn_metadata)
            metadata = super().build(common_prefix_len,
                                     new_common_attn_metadata, fast_build)
            wrapper_cls = _fast_prefill_metadata_cls(metadata.__class__)
            # The extra fields come from the original (not the rewritten)
            # common metadata.
            return wrapper_cls(metadata, common_attn_metadata)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder)

    return attn_backend
......@@ -335,6 +335,13 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller.
"""
if (self.vllm_config.cache_config.kv_sharing_fast_prefill
and sampling_params.prompt_logprobs):
raise ValueError(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests need "
"prompt logprobs")
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import gc
import itertools
import time
......@@ -58,7 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
supports_dynamo)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
create_fast_prefill_custom_backend,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
......@@ -84,9 +83,10 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from .utils import (AttentionGroup, MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING:
import xgrammar as xgr
......@@ -860,6 +860,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_seq_len=max_seq_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
causal=True,
)
......@@ -884,28 +886,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata=common_attn_metadata,
))
fast_prefill_metadata = attn_metadata_i
if (self.cache_config.kv_sharing_fast_prefill
and self.kv_sharing_fast_prefill_eligible_layers):
# Dynamically create a dataclass type that inherits
# from attention metadata type but includes additional
# fields logits_indices_padded and num_logits_indices
# which are required for prefill truncation
fast_prefill_metadata_type = (
make_kv_sharing_fast_prefill_attention_metadata(
metadata_cls=type(attn_metadata_i), ))
fast_prefill_metadata = fast_prefill_metadata_type(
**dataclasses.asdict(attn_metadata_i),
logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0),
)
for layer_name in attn_group.layer_names:
if (self.cache_config.kv_sharing_fast_prefill
and layer_name
in self.kv_sharing_fast_prefill_eligible_layers):
attn_metadata[layer_name] = fast_prefill_metadata
continue
attn_metadata[layer_name] = attn_metadata_i
# Hot-Swap lora model
......@@ -1484,6 +1465,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
    # Fast prefill cannot be combined with prompt logprobs; fail loudly if
    # any request in the batch asks for them.
    assert not self.input_batch.num_prompt_logprobs, (
        "--kv-sharing-fast-prefill produces incorrect logprobs for "
        "prompt tokens, please disable it when the requests need "
        "prompt logprobs")
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
......@@ -2742,6 +2729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# layer.
for layer_name in layer_names:
attn_backend = layers[layer_name].get_attn_backend()
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend(
"FastPrefill",
attn_backend,
)
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
attn_backend_layers[key].append(layer_name)
......@@ -3074,20 +3068,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if self.shared_kv_cache_layers:
initialize_kv_cache_for_kv_sharing(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
kv_caches,
self.attn_groups,
self.runner_only_attn_layers,
)
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
):
logger.debug("%s reuses KV cache of %s", layer_name,
target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name]
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
self, kv_cache_config: KVCacheConfig) -> None:
"""
Add layers that re-use KV cache to KV cache group of its target layer.
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
"""
if not self.shared_kv_cache_layers:
# No cross-layer KV sharing, return
return
add_kv_sharing_layers_to_kv_cache_groups(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
self.runner_only_attn_layers,
)
if self.cache_config.kv_sharing_fast_prefill:
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
# similar KV sharing setups, only the layers that generate KV caches
# are involved in the prefill phase, enabling prefill to early exit.
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
# Iterate in reversed order and add layers that re-use KV cache
# e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
for layer_name in reversed(attn_layers):
if layer_name in self.shared_kv_cache_layers:
self.kv_sharing_fast_prefill_eligible_layers.add(
......@@ -3095,11 +3109,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
break
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
......@@ -3111,6 +3120,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
......
......@@ -55,9 +55,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (MultiModalBudget, bind_kv_cache,
initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs)
from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache, sanity_check_mm_encoder_outputs)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -1599,6 +1598,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.encoder_cache.clear()
gc.collect()
def maybe_setup_cross_layer_kv_sharing(
    self,
    kv_caches: dict[str, torch.Tensor],
    kv_cache_config: KVCacheConfig,
) -> None:
    """Wire up cross-layer KV cache sharing, if the model uses it.

    Does nothing when ``self.shared_kv_cache_layers`` is empty. Otherwise
    the layers that re-use another layer's KV cache are added to the KV
    cache groups via ``add_kv_sharing_layers_to_kv_cache_groups``, and
    ``kv_caches`` is updated in place so that each sharing layer's entry is
    the very tensor of its target layer.

    Args:
        kv_caches: KV cache tensors keyed by layer name; mutated in place.
        kv_cache_config: Config whose ``kv_cache_groups`` are extended.
    """
    sharing_map = self.shared_kv_cache_layers
    if not sharing_map:
        # No cross-layer KV sharing, return
        return

    add_kv_sharing_layers_to_kv_cache_groups(
        sharing_map,
        kv_cache_config.kv_cache_groups,
    )

    for sharing_layer, source_layer in sharing_map.items():
        logger.debug("%s reuses KV cache of %s", sharing_layer, source_layer)
        kv_caches[sharing_layer] = kv_caches[source_layer]
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
......@@ -1664,14 +1687,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
raise NotImplementedError
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if self.shared_kv_cache_layers:
initialize_kv_cache_for_kv_sharing(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
kv_caches,
)
# Set up cross-layer KV cache sharing if needed
self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)
bind_kv_cache(
kv_caches,
......
......@@ -203,12 +203,9 @@ def gather_mm_placeholders(
return placeholders[is_embed]
def initialize_kv_cache_for_kv_sharing(
def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
kv_caches: dict[str, torch.Tensor],
# Optional for now to avoid breaking TPU
attn_groups: Optional[list[list[AttentionGroup]]] = None,
runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
"""
......@@ -223,38 +220,15 @@ def initialize_kv_cache_for_kv_sharing(
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
kv_caches: The allocated kv_caches with layer names as keys.
Note that layers in shared_kv_cache_layers.keys() are not
originally included as it only contains layers which have its own
KV cache allocation.
attn_groups: Optional list of attention groups. Layers in the same KV
cache group may be placed in different attention groups if they
have different attention backends. Currently only provided by
GPU model runner.
"""
# mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
if attn_groups:
for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
for attn_group_idx, attn_group in enumerate(kv_attn_groups):
for layer_name in attn_group.layer_names:
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
attn_group_idx)
else:
for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
for layer_name in kv_cache_group.layer_names:
# attn group idx default to 0 if not provided
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
for kv_cache_group in kv_cache_groups:
for layer_name in kv_cache_group.layer_names:
layer_to_kv_cache_group[layer_name] = kv_cache_group
for layer_name, target_layer_name in shared_kv_cache_layers.items():
kv_caches[layer_name] = kv_caches[target_layer_name]
kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)
if attn_groups:
attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
layer_name)
tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
tgt_kv_cache_group.layer_names.append(layer_name)
if runner_only_attn_layers is not None:
runner_only_attn_layers.add(layer_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment