Commit 6f3dc4d0 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by khluu
Browse files

[Gemma4] Enable Fast Prefill Optimization (#38879)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit 47e60509)
(cherry picked from commit fc29ef13a0dff7c4fa14590655a160f20fb7124d)
parent a6ac49a3
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
"""Gemma 4 model implementation for vLLM.""" """Gemma 4 model implementation for vLLM."""
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import replace
from itertools import islice from itertools import islice
import regex as re import regex as re
...@@ -32,6 +33,7 @@ from vllm.distributed import ( ...@@ -32,6 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
...@@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
...@@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module): ...@@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module):
return hidden_states, None return hidden_states, None
@support_torch_compile def _run_decoder_layers(
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
"""Run a slice of decoder layers with PLE extraction."""
residual = None
for idx, layer in enumerate(decoder_layers):
layer_idx = idx + layer_idx_start
layer_per_input = (
per_layer_inputs[:, layer_idx, :] if per_layer_inputs is not None else None
)
hidden_states, residual = layer(
positions,
hidden_states,
residual,
per_layer_input=layer_per_input,
**kwargs,
)
return hidden_states
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
"""Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).
Owns the embedding and PLE modules so they are inside the compiled
graph. Gemma4Model delegates embedding methods here.
"""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
embed_tokens: VocabParallelEmbedding,
normalizer: torch.Tensor,
embed_tokens_per_layer: VocabParallelEmbedding | None,
embed_scale_per_layer: torch.Tensor | None,
per_layer_model_projection: ColumnParallelLinear | None,
per_layer_projection_norm: RMSNorm | None,
per_layer_input_scale: torch.Tensor | None,
per_layer_projection_scale: torch.Tensor | None,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
config = _get_text_config(vllm_config.model_config.hf_config)
self.config = config
self.hidden_size_per_layer_input = getattr(
config, "hidden_size_per_layer_input", 0
)
self.vocab_size_per_layer_input = getattr(
config, "vocab_size_per_layer_input", config.vocab_size
)
# Shared references to modules owned by Gemma4Model — must be
# inside this nn.Module so torch.compile captures them.
self.embed_tokens = embed_tokens
self.normalizer = normalizer
self.embed_tokens_per_layer = embed_tokens_per_layer
self.embed_scale_per_layer = embed_scale_per_layer
self.per_layer_model_projection = per_layer_model_projection
self.per_layer_projection_norm = per_layer_projection_norm
self.per_layer_input_scale = per_layer_input_scale
self.per_layer_projection_scale = per_layer_projection_scale
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if self.embed_tokens_per_layer is None:
return None
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
return per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if self.per_layer_model_projection is None:
return None
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if per_layer_inputs is None:
return per_layer_projection
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if inputs_embeds is not None:
hidden_states = inputs_embeds
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_inputs
)
else:
hidden_states = self.embed_input_ids(input_ids)
per_layer_embeds = self.get_per_layer_inputs(input_ids)
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_embeds
)
hidden_states = _run_decoder_layers(
self.decoder_layers,
self.layer_idx_start,
positions,
hidden_states,
per_layer_inputs,
**kwargs,
)
return hidden_states, per_layer_inputs
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
"""Cross-decoder layers (YOCO second half, KV-shared)."""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
return _run_decoder_layers(
self.decoder_layers,
self.layer_idx_start,
positions,
hidden_states,
per_layer_inputs,
**kwargs,
)
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module): class Gemma4Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -740,6 +941,75 @@ class Gemma4Model(nn.Module): ...@@ -740,6 +941,75 @@ class Gemma4Model(nn.Module):
torch.tensor(config.hidden_size**0.5), torch.tensor(config.hidden_size**0.5),
persistent=False, persistent=False,
) )
# --- You Only Cache Once (YOCO) split for fast prefill ---
first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
config, "num_kv_shared_layers", 0
)
from vllm.compilation.backends import set_model_tag
# Layers 0..(K-1) are self-decoder layers in YOCO
with set_model_tag("self_decoder"):
self.self_decoder = Gemma4SelfDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.self_decoder",
decoder_layers=self.layers[:first_kv_shared_layer_idx],
layer_idx_start=0,
embed_tokens=self.embed_tokens,
normalizer=self.normalizer,
embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None),
embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None),
per_layer_model_projection=getattr(
self, "per_layer_model_projection", None
),
per_layer_projection_norm=getattr(
self, "per_layer_projection_norm", None
),
per_layer_input_scale=getattr(self, "per_layer_input_scale", None),
per_layer_projection_scale=getattr(
self, "per_layer_projection_scale", None
),
)
# Layers K..(N-1) are cross-decoder layers in YOCO
with set_model_tag("cross_decoder"):
self.cross_decoder = Gemma4CrossDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.cross_decoder",
decoder_layers=self.layers[first_kv_shared_layer_idx:],
layer_idx_start=first_kv_shared_layer_idx,
)
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
if self.fast_prefill_enabled:
# Allocate static buffers for CUDAGraph
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
device = next(self.parameters()).device
self.positions = torch.zeros(
max_num_tokens, dtype=torch.int64, device=device
)
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
if (
self.hidden_size_per_layer_input
and self.hidden_size_per_layer_input > 0
):
self.per_layer_inputs = torch.zeros(
(
max_num_tokens,
config.num_hidden_layers,
self.hidden_size_per_layer_input,
),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
else:
self.per_layer_inputs = None
# Custom factory that includes per_layer_inputs for PLE-enabled PP. # Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim), # per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape, # which differs from the standard (batch, hidden_size) shape,
...@@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module): ...@@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module):
self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer return self.self_decoder.embed_input_ids(input_ids)
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor: def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer. """Get per-layer embeddings from embed_tokens_per_layer.
Returns: Returns:
Per-layer embeddings (num_tokens, num_layers, Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input) hidden_size_per_layer_input)
""" """
if self.embed_tokens_per_layer is None: return self.self_decoder.get_per_layer_inputs(input_ids)
return None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds = per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
return per_layer_embeds
def project_per_layer_inputs( def project_per_layer_inputs(
self, self,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None, per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs. """Project inputs_embeds and combine with per_layer_inputs.
Steps: Steps:
...@@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module): ...@@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module):
4. Normalize with per_layer_projection_norm 4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2) 5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
""" """
if self.per_layer_model_projection is None: return self.self_decoder.project_per_layer_inputs(
return None inputs_embeds, per_layer_inputs
)
# Project from hidden_size to total_ple_dim def fast_prefill_forward(
# Scaled projection: output = linear(input, weight) * scale self,
per_layer_projection = self.per_layer_model_projection(inputs_embeds) input_ids: torch.Tensor | None,
per_layer_projection = per_layer_projection * self.per_layer_projection_scale positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
logits_indices_padded, num_logits_indices = None, None
attn_metadata = get_forward_context().attn_metadata
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) if attn_metadata is not None:
per_layer_projection = per_layer_projection.reshape( assert isinstance(attn_metadata, dict)
*inputs_embeds.shape[:-1], layer_attn_metadata = attn_metadata[
self.config.num_hidden_layers, self.layers[-1].self_attn.attn.layer_name
self.hidden_size_per_layer_input, ]
if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
logits_indices_padded = layer_attn_metadata.logits_indices_padded
num_logits_indices = layer_attn_metadata.num_logits_indices
batch_size = positions.size(0)
self.positions[:batch_size].copy_(positions)
self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
input_ids=input_ids,
positions=self.positions[:batch_size],
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
) )
# Normalize if logits_indices_padded is None:
per_layer_projection = self.per_layer_projection_norm(per_layer_projection) logits_indices_padded = torch.arange(
batch_size,
dtype=positions.dtype,
device=positions.device,
)
if per_layer_inputs is None: # NOTE: Keep .clone() until fix in
return per_layer_projection # https://github.com/vllm-project/vllm/pull/22282
hidden_states = self_decoder_hidden_states.clone()
# Combine: (projection + per_layer_inputs) * scale num_padded = logits_indices_padded.size(0)
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale self.positions[:num_padded].copy_(positions[logits_indices_padded])
self.hidden_states[:num_padded].copy_(
self_decoder_hidden_states[logits_indices_padded]
)
if self.per_layer_inputs is not None and per_layer_inputs is not None:
self.per_layer_inputs[:num_padded].copy_(
per_layer_inputs[logits_indices_padded]
)
# Update batch_descriptor so the cross-decoder's piecewise
# CUDAGraphWrapper dispatches to the correct (reduced) batch size.
forward_context = get_forward_context()
orig_batch_desc = forward_context.batch_descriptor
if orig_batch_desc is not None:
forward_context.batch_descriptor = replace(
orig_batch_desc, num_tokens=num_padded
)
cross_per_layer = (
self.per_layer_inputs[:num_padded]
if self.per_layer_inputs is not None
else None
)
cross_hidden_states = self.cross_decoder(
self.positions[:num_padded],
self.hidden_states[:num_padded],
cross_per_layer,
**kwargs,
)
# Restore the original batch_descriptor
forward_context.batch_descriptor = orig_batch_desc
if num_logits_indices is not None:
assert num_logits_indices > 0
hidden_states[logits_indices_padded[:num_logits_indices]] = (
cross_hidden_states[:num_logits_indices]
)
else:
hidden_states = cross_hidden_states
return hidden_states
def forward( def forward(
self, self,
...@@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module): ...@@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module):
per_layer_inputs: torch.Tensor | None = None, per_layer_inputs: torch.Tensor | None = None,
**kwargs, **kwargs,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return hidden_states
# Normal (non-fast-prefill) path with PP support
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment