Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47e60509
Unverified
Commit
47e60509
authored
Apr 06, 2026
by
Lucas Wilkinson
Committed by
GitHub
Apr 06, 2026
Browse files
[Gemma4] Enable Fast Prefill Optimization (#38879)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
e69a2651
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
369 additions
and
47 deletions
+369
-47
vllm/model_executor/models/gemma4.py
vllm/model_executor/models/gemma4.py
+369
-47
No files found.
vllm/model_executor/models/gemma4.py
View file @
47e60509
...
...
@@ -19,6 +19,7 @@
"""Gemma 4 model implementation for vLLM."""
from
collections.abc
import
Iterable
from
dataclasses
import
replace
from
itertools
import
islice
import
regex
as
re
...
...
@@ -32,6 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
Attention
...
...
@@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name
,
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
...
...
@@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module):
return
hidden_states
,
None
@
support_torch_compile
def
_run_decoder_layers
(
decoder_layers
:
list
[
Gemma4DecoderLayer
],
layer_idx_start
:
int
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
"""Run a slice of decoder layers with PLE extraction."""
residual
=
None
for
idx
,
layer
in
enumerate
(
decoder_layers
):
layer_idx
=
idx
+
layer_idx_start
layer_per_input
=
(
per_layer_inputs
[:,
layer_idx
,
:]
if
per_layer_inputs
is
not
None
else
None
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
per_layer_input
=
layer_per_input
,
**
kwargs
,
)
return
hidden_states
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma4SelfDecoderLayers
(
nn
.
Module
):
"""Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).
Owns the embedding and PLE modules so they are inside the compiled
graph. Gemma4Model delegates embedding methods here.
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma4DecoderLayer
],
layer_idx_start
:
int
,
embed_tokens
:
VocabParallelEmbedding
,
normalizer
:
torch
.
Tensor
,
embed_tokens_per_layer
:
VocabParallelEmbedding
|
None
,
embed_scale_per_layer
:
torch
.
Tensor
|
None
,
per_layer_model_projection
:
ColumnParallelLinear
|
None
,
per_layer_projection_norm
:
RMSNorm
|
None
,
per_layer_input_scale
:
torch
.
Tensor
|
None
,
per_layer_projection_scale
:
torch
.
Tensor
|
None
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
config
=
_get_text_config
(
vllm_config
.
model_config
.
hf_config
)
self
.
config
=
config
self
.
hidden_size_per_layer_input
=
getattr
(
config
,
"hidden_size_per_layer_input"
,
0
)
self
.
vocab_size_per_layer_input
=
getattr
(
config
,
"vocab_size_per_layer_input"
,
config
.
vocab_size
)
# Shared references to modules owned by Gemma4Model — must be
# inside this nn.Module so torch.compile captures them.
self
.
embed_tokens
=
embed_tokens
self
.
normalizer
=
normalizer
self
.
embed_tokens_per_layer
=
embed_tokens_per_layer
self
.
embed_scale_per_layer
=
embed_scale_per_layer
self
.
per_layer_model_projection
=
per_layer_model_projection
self
.
per_layer_projection_norm
=
per_layer_projection_norm
self
.
per_layer_input_scale
=
per_layer_input_scale
self
.
per_layer_projection_scale
=
per_layer_projection_scale
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
normalizer
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if
self
.
embed_tokens_per_layer
is
None
:
return
None
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
vocab_size_per_layer_input
,
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
)
)
per_layer_embeds
=
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
per_layer_embeds
=
per_layer_embeds
*
self
.
embed_scale_per_layer
return
per_layer_embeds
.
reshape
(
*
input_ids
.
shape
,
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
)
def
project_per_layer_inputs
(
self
,
inputs_embeds
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
None
:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if
self
.
per_layer_model_projection
is
None
:
return
None
per_layer_projection
=
self
.
per_layer_model_projection
(
inputs_embeds
)
per_layer_projection
=
per_layer_projection
*
self
.
per_layer_projection_scale
per_layer_projection
=
per_layer_projection
.
reshape
(
*
inputs_embeds
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
)
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
per_layer_inputs
is
None
:
return
per_layer_projection
return
(
per_layer_projection
+
per_layer_inputs
)
*
self
.
per_layer_input_scale
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
per_layer_inputs
=
self
.
project_per_layer_inputs
(
hidden_states
,
per_layer_inputs
)
else
:
hidden_states
=
self
.
embed_input_ids
(
input_ids
)
per_layer_embeds
=
self
.
get_per_layer_inputs
(
input_ids
)
per_layer_inputs
=
self
.
project_per_layer_inputs
(
hidden_states
,
per_layer_embeds
)
hidden_states
=
_run_decoder_layers
(
self
.
decoder_layers
,
self
.
layer_idx_start
,
positions
,
hidden_states
,
per_layer_inputs
,
**
kwargs
,
)
return
hidden_states
,
per_layer_inputs
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma4CrossDecoderLayers
(
nn
.
Module
):
"""Cross-decoder layers (YOCO second half, KV-shared)."""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma4DecoderLayer
],
layer_idx_start
:
int
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
return
_run_decoder_layers
(
self
.
decoder_layers
,
self
.
layer_idx_start
,
positions
,
hidden_states
,
per_layer_inputs
,
**
kwargs
,
)
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma4Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -740,6 +941,75 @@ class Gemma4Model(nn.Module):
torch
.
tensor
(
config
.
hidden_size
**
0.5
),
persistent
=
False
,
)
# --- You Only Cache Once (YOCO) split for fast prefill ---
first_kv_shared_layer_idx
=
config
.
num_hidden_layers
-
getattr
(
config
,
"num_kv_shared_layers"
,
0
)
from
vllm.compilation.backends
import
set_model_tag
# Layers 0..(K-1) are self-decoder layers in YOCO
with
set_model_tag
(
"self_decoder"
):
self
.
self_decoder
=
Gemma4SelfDecoderLayers
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.self_decoder"
,
decoder_layers
=
self
.
layers
[:
first_kv_shared_layer_idx
],
layer_idx_start
=
0
,
embed_tokens
=
self
.
embed_tokens
,
normalizer
=
self
.
normalizer
,
embed_tokens_per_layer
=
getattr
(
self
,
"embed_tokens_per_layer"
,
None
),
embed_scale_per_layer
=
getattr
(
self
,
"embed_scale_per_layer"
,
None
),
per_layer_model_projection
=
getattr
(
self
,
"per_layer_model_projection"
,
None
),
per_layer_projection_norm
=
getattr
(
self
,
"per_layer_projection_norm"
,
None
),
per_layer_input_scale
=
getattr
(
self
,
"per_layer_input_scale"
,
None
),
per_layer_projection_scale
=
getattr
(
self
,
"per_layer_projection_scale"
,
None
),
)
# Layers K..(N-1) are cross-decoder layers in YOCO
with
set_model_tag
(
"cross_decoder"
):
self
.
cross_decoder
=
Gemma4CrossDecoderLayers
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.cross_decoder"
,
decoder_layers
=
self
.
layers
[
first_kv_shared_layer_idx
:],
layer_idx_start
=
first_kv_shared_layer_idx
,
)
self
.
fast_prefill_enabled
=
cache_config
.
kv_sharing_fast_prefill
if
self
.
fast_prefill_enabled
:
# Allocate static buffers for CUDAGraph
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
device
=
next
(
self
.
parameters
()).
device
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hidden_states
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
hidden_size
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
if
(
self
.
hidden_size_per_layer_input
and
self
.
hidden_size_per_layer_input
>
0
):
self
.
per_layer_inputs
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
else
:
self
.
per_layer_inputs
=
None
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape,
...
...
@@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module):
self
.
make_empty_intermediate_tensors
=
_make_empty_intermediate_tensors
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
normalizer
return
self
.
self_decoder
.
embed_input_ids
(
input_ids
)
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if
self
.
embed_tokens_per_layer
is
None
:
return
None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
vocab_size_per_layer_input
,
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds
=
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds
=
per_layer_embeds
*
self
.
embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds
=
per_layer_embeds
.
reshape
(
*
input_ids
.
shape
,
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
)
return
per_layer_embeds
return
self
.
self_decoder
.
get_per_layer_inputs
(
input_ids
)
def
project_per_layer_inputs
(
self
,
inputs_embeds
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
|
None
:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
...
...
@@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module):
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if
self
.
per_layer_model_projection
is
None
:
return
None
return
self
.
self_decoder
.
project_per_layer_inputs
(
inputs_embeds
,
per_layer_inputs
)
# Project from hidden_size to total_ple_dim
# Scaled projection: output = linear(input, weight) * scale
per_layer_projection
=
self
.
per_layer_model_projection
(
inputs_embeds
)
per_layer_projection
=
per_layer_projection
*
self
.
per_layer_projection_scale
def
fast_prefill_forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
logits_indices_padded
,
num_logits_indices
=
None
,
None
attn_metadata
=
get_forward_context
().
attn_metadata
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_projection
=
per_layer_projection
.
reshape
(
*
inputs_embeds
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
layer_attn_metadata
=
attn_metadata
[
self
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
]
if
isinstance
(
layer_attn_metadata
,
KVSharingFastPrefillMetadata
):
logits_indices_padded
=
layer_attn_metadata
.
logits_indices_padded
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
batch_size
=
positions
.
size
(
0
)
self
.
positions
[:
batch_size
].
copy_
(
positions
)
self_decoder_hidden_states
,
per_layer_inputs
=
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
self
.
positions
[:
batch_size
],
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
# Normalize
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
logits_indices_padded
is
None
:
logits_indices_padded
=
torch
.
arange
(
batch_size
,
dtype
=
positions
.
dtype
,
device
=
positions
.
device
,
)
if
per_layer_inputs
is
None
:
return
per_layer_projection
# NOTE: Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states
=
self_decoder_hidden_states
.
clone
()
# Combine: (projection + per_layer_inputs) * scale
return
(
per_layer_projection
+
per_layer_inputs
)
*
self
.
per_layer_input_scale
num_padded
=
logits_indices_padded
.
size
(
0
)
self
.
positions
[:
num_padded
].
copy_
(
positions
[
logits_indices_padded
])
self
.
hidden_states
[:
num_padded
].
copy_
(
self_decoder_hidden_states
[
logits_indices_padded
]
)
if
self
.
per_layer_inputs
is
not
None
and
per_layer_inputs
is
not
None
:
self
.
per_layer_inputs
[:
num_padded
].
copy_
(
per_layer_inputs
[
logits_indices_padded
]
)
# Update batch_descriptor so the cross-decoder's piecewise
# CUDAGraphWrapper dispatches to the correct (reduced) batch size.
forward_context
=
get_forward_context
()
orig_batch_desc
=
forward_context
.
batch_descriptor
if
orig_batch_desc
is
not
None
:
forward_context
.
batch_descriptor
=
replace
(
orig_batch_desc
,
num_tokens
=
num_padded
)
cross_per_layer
=
(
self
.
per_layer_inputs
[:
num_padded
]
if
self
.
per_layer_inputs
is
not
None
else
None
)
cross_hidden_states
=
self
.
cross_decoder
(
self
.
positions
[:
num_padded
],
self
.
hidden_states
[:
num_padded
],
cross_per_layer
,
**
kwargs
,
)
# Restore the original batch_descriptor
forward_context
.
batch_descriptor
=
orig_batch_desc
if
num_logits_indices
is
not
None
:
assert
num_logits_indices
>
0
hidden_states
[
logits_indices_padded
[:
num_logits_indices
]]
=
(
cross_hidden_states
[:
num_logits_indices
]
)
else
:
hidden_states
=
cross_hidden_states
return
hidden_states
def
forward
(
self
,
...
...
@@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module):
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
if
self
.
fast_prefill_enabled
:
hidden_states
=
self
.
fast_prefill_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
# Normal (non-fast-prefill) path with PP support
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment