Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47e60509
Unverified
Commit
47e60509
authored
Apr 06, 2026
by
Lucas Wilkinson
Committed by
GitHub
Apr 06, 2026
Browse files
[Gemma4] Enable Fast Prefill Optimization (#38879)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
e69a2651
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
369 additions
and
47 deletions
+369
-47
vllm/model_executor/models/gemma4.py
vllm/model_executor/models/gemma4.py
+369
-47
No files found.
vllm/model_executor/models/gemma4.py
View file @
47e60509
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
"""Gemma 4 model implementation for vLLM."""
"""Gemma 4 model implementation for vLLM."""
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
dataclasses
import
replace
from
itertools
import
islice
from
itertools
import
islice
import
regex
as
re
import
regex
as
re
...
@@ -32,6 +33,7 @@ from vllm.distributed import (
...
@@ -32,6 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
)
)
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
...
@@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -56,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name
,
maybe_remap_kv_scale_name
,
)
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
from
.utils
import
(
...
@@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module):
...
@@ -636,7 +639,205 @@ class Gemma4DecoderLayer(nn.Module):
return
hidden_states
,
None
return
hidden_states
,
None
@
support_torch_compile
def
_run_decoder_layers
(
decoder_layers
:
list
[
Gemma4DecoderLayer
],
layer_idx_start
:
int
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
"""Run a slice of decoder layers with PLE extraction."""
residual
=
None
for
idx
,
layer
in
enumerate
(
decoder_layers
):
layer_idx
=
idx
+
layer_idx_start
layer_per_input
=
(
per_layer_inputs
[:,
layer_idx
,
:]
if
per_layer_inputs
is
not
None
else
None
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
per_layer_input
=
layer_per_input
,
**
kwargs
,
)
return
hidden_states
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma4SelfDecoderLayers
(
nn
.
Module
):
"""Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).
Owns the embedding and PLE modules so they are inside the compiled
graph. Gemma4Model delegates embedding methods here.
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma4DecoderLayer
],
layer_idx_start
:
int
,
embed_tokens
:
VocabParallelEmbedding
,
normalizer
:
torch
.
Tensor
,
embed_tokens_per_layer
:
VocabParallelEmbedding
|
None
,
embed_scale_per_layer
:
torch
.
Tensor
|
None
,
per_layer_model_projection
:
ColumnParallelLinear
|
None
,
per_layer_projection_norm
:
RMSNorm
|
None
,
per_layer_input_scale
:
torch
.
Tensor
|
None
,
per_layer_projection_scale
:
torch
.
Tensor
|
None
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
config
=
_get_text_config
(
vllm_config
.
model_config
.
hf_config
)
self
.
config
=
config
self
.
hidden_size_per_layer_input
=
getattr
(
config
,
"hidden_size_per_layer_input"
,
0
)
self
.
vocab_size_per_layer_input
=
getattr
(
config
,
"vocab_size_per_layer_input"
,
config
.
vocab_size
)
# Shared references to modules owned by Gemma4Model — must be
# inside this nn.Module so torch.compile captures them.
self
.
embed_tokens
=
embed_tokens
self
.
normalizer
=
normalizer
self
.
embed_tokens_per_layer
=
embed_tokens_per_layer
self
.
embed_scale_per_layer
=
embed_scale_per_layer
self
.
per_layer_model_projection
=
per_layer_model_projection
self
.
per_layer_projection_norm
=
per_layer_projection_norm
self
.
per_layer_input_scale
=
per_layer_input_scale
self
.
per_layer_projection_scale
=
per_layer_projection_scale
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
normalizer
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if
self
.
embed_tokens_per_layer
is
None
:
return
None
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
vocab_size_per_layer_input
,
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
)
)
per_layer_embeds
=
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
per_layer_embeds
=
per_layer_embeds
*
self
.
embed_scale_per_layer
return
per_layer_embeds
.
reshape
(
*
input_ids
.
shape
,
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
)
def
project_per_layer_inputs
(
self
,
inputs_embeds
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
None
:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if
self
.
per_layer_model_projection
is
None
:
return
None
per_layer_projection
=
self
.
per_layer_model_projection
(
inputs_embeds
)
per_layer_projection
=
per_layer_projection
*
self
.
per_layer_projection_scale
per_layer_projection
=
per_layer_projection
.
reshape
(
*
inputs_embeds
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
)
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
per_layer_inputs
is
None
:
return
per_layer_projection
return
(
per_layer_projection
+
per_layer_inputs
)
*
self
.
per_layer_input_scale
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
per_layer_inputs
=
self
.
project_per_layer_inputs
(
hidden_states
,
per_layer_inputs
)
else
:
hidden_states
=
self
.
embed_input_ids
(
input_ids
)
per_layer_embeds
=
self
.
get_per_layer_inputs
(
input_ids
)
per_layer_inputs
=
self
.
project_per_layer_inputs
(
hidden_states
,
per_layer_embeds
)
hidden_states
=
_run_decoder_layers
(
self
.
decoder_layers
,
self
.
layer_idx_start
,
positions
,
hidden_states
,
per_layer_inputs
,
**
kwargs
,
)
return
hidden_states
,
per_layer_inputs
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma4CrossDecoderLayers
(
nn
.
Module
):
"""Cross-decoder layers (YOCO second half, KV-shared)."""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma4DecoderLayer
],
layer_idx_start
:
int
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
return
_run_decoder_layers
(
self
.
decoder_layers
,
self
.
layer_idx_start
,
positions
,
hidden_states
,
per_layer_inputs
,
**
kwargs
,
)
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma4Model
(
nn
.
Module
):
class
Gemma4Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -740,6 +941,75 @@ class Gemma4Model(nn.Module):
...
@@ -740,6 +941,75 @@ class Gemma4Model(nn.Module):
torch
.
tensor
(
config
.
hidden_size
**
0.5
),
torch
.
tensor
(
config
.
hidden_size
**
0.5
),
persistent
=
False
,
persistent
=
False
,
)
)
# --- You Only Cache Once (YOCO) split for fast prefill ---
first_kv_shared_layer_idx
=
config
.
num_hidden_layers
-
getattr
(
config
,
"num_kv_shared_layers"
,
0
)
from
vllm.compilation.backends
import
set_model_tag
# Layers 0..(K-1) are self-decoder layers in YOCO
with
set_model_tag
(
"self_decoder"
):
self
.
self_decoder
=
Gemma4SelfDecoderLayers
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.self_decoder"
,
decoder_layers
=
self
.
layers
[:
first_kv_shared_layer_idx
],
layer_idx_start
=
0
,
embed_tokens
=
self
.
embed_tokens
,
normalizer
=
self
.
normalizer
,
embed_tokens_per_layer
=
getattr
(
self
,
"embed_tokens_per_layer"
,
None
),
embed_scale_per_layer
=
getattr
(
self
,
"embed_scale_per_layer"
,
None
),
per_layer_model_projection
=
getattr
(
self
,
"per_layer_model_projection"
,
None
),
per_layer_projection_norm
=
getattr
(
self
,
"per_layer_projection_norm"
,
None
),
per_layer_input_scale
=
getattr
(
self
,
"per_layer_input_scale"
,
None
),
per_layer_projection_scale
=
getattr
(
self
,
"per_layer_projection_scale"
,
None
),
)
# Layers K..(N-1) are cross-decoder layers in YOCO
with
set_model_tag
(
"cross_decoder"
):
self
.
cross_decoder
=
Gemma4CrossDecoderLayers
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.cross_decoder"
,
decoder_layers
=
self
.
layers
[
first_kv_shared_layer_idx
:],
layer_idx_start
=
first_kv_shared_layer_idx
,
)
self
.
fast_prefill_enabled
=
cache_config
.
kv_sharing_fast_prefill
if
self
.
fast_prefill_enabled
:
# Allocate static buffers for CUDAGraph
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
device
=
next
(
self
.
parameters
()).
device
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hidden_states
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
hidden_size
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
if
(
self
.
hidden_size_per_layer_input
and
self
.
hidden_size_per_layer_input
>
0
):
self
.
per_layer_inputs
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
else
:
self
.
per_layer_inputs
=
None
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape,
# which differs from the standard (batch, hidden_size) shape,
...
@@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module):
...
@@ -776,47 +1046,22 @@ class Gemma4Model(nn.Module):
self
.
make_empty_intermediate_tensors
=
_make_empty_intermediate_tensors
self
.
make_empty_intermediate_tensors
=
_make_empty_intermediate_tensors
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
normalizer
return
self
.
self_decoder
.
embed_input_ids
(
input_ids
)
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
"""Get per-layer embeddings from embed_tokens_per_layer.
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Returns:
Per-layer embeddings (num_tokens, num_layers,
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
hidden_size_per_layer_input)
"""
"""
if
self
.
embed_tokens_per_layer
is
None
:
return
self
.
self_decoder
.
get_per_layer_inputs
(
input_ids
)
return
None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
vocab_size_per_layer_input
,
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds
=
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds
=
per_layer_embeds
*
self
.
embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds
=
per_layer_embeds
.
reshape
(
*
input_ids
.
shape
,
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
)
return
per_layer_embeds
def
project_per_layer_inputs
(
def
project_per_layer_inputs
(
self
,
self
,
inputs_embeds
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
|
None
:
"""Project inputs_embeds and combine with per_layer_inputs.
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
Steps:
...
@@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module):
...
@@ -826,29 +1071,94 @@ class Gemma4Model(nn.Module):
4. Normalize with per_layer_projection_norm
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
"""
if
self
.
per_layer_model_projection
is
None
:
return
self
.
self_decoder
.
project_per_layer_inputs
(
return
None
inputs_embeds
,
per_layer_inputs
)
# Project from hidden_size to total_ple_dim
def
fast_prefill_forward
(
# Scaled projection: output = linear(input, weight) * scale
self
,
per_layer_projection
=
self
.
per_layer_model_projection
(
inputs_embeds
)
input_ids
:
torch
.
Tensor
|
None
,
per_layer_projection
=
per_layer_projection
*
self
.
per_layer_projection_scale
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
logits_indices_padded
,
num_logits_indices
=
None
,
None
attn_metadata
=
get_forward_context
().
attn_metadata
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
if
attn_metadata
is
not
None
:
per_layer_projection
=
per_layer_projection
.
reshape
(
assert
isinstance
(
attn_metadata
,
dict
)
*
inputs_embeds
.
shape
[:
-
1
],
layer_attn_metadata
=
attn_metadata
[
self
.
config
.
num_hidden_layers
,
self
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
self
.
hidden_size_per_layer_input
,
]
if
isinstance
(
layer_attn_metadata
,
KVSharingFastPrefillMetadata
):
logits_indices_padded
=
layer_attn_metadata
.
logits_indices_padded
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
batch_size
=
positions
.
size
(
0
)
self
.
positions
[:
batch_size
].
copy_
(
positions
)
self_decoder_hidden_states
,
per_layer_inputs
=
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
self
.
positions
[:
batch_size
],
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
)
# Normalize
if
logits_indices_padded
is
None
:
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
logits_indices_padded
=
torch
.
arange
(
batch_size
,
dtype
=
positions
.
dtype
,
device
=
positions
.
device
,
)
if
per_layer_inputs
is
None
:
# NOTE: Keep .clone() until fix in
return
per_layer_projection
# https://github.com/vllm-project/vllm/pull/22282
hidden_states
=
self_decoder_hidden_states
.
clone
()
# Combine: (projection + per_layer_inputs) * scale
num_padded
=
logits_indices_padded
.
size
(
0
)
return
(
per_layer_projection
+
per_layer_inputs
)
*
self
.
per_layer_input_scale
self
.
positions
[:
num_padded
].
copy_
(
positions
[
logits_indices_padded
])
self
.
hidden_states
[:
num_padded
].
copy_
(
self_decoder_hidden_states
[
logits_indices_padded
]
)
if
self
.
per_layer_inputs
is
not
None
and
per_layer_inputs
is
not
None
:
self
.
per_layer_inputs
[:
num_padded
].
copy_
(
per_layer_inputs
[
logits_indices_padded
]
)
# Update batch_descriptor so the cross-decoder's piecewise
# CUDAGraphWrapper dispatches to the correct (reduced) batch size.
forward_context
=
get_forward_context
()
orig_batch_desc
=
forward_context
.
batch_descriptor
if
orig_batch_desc
is
not
None
:
forward_context
.
batch_descriptor
=
replace
(
orig_batch_desc
,
num_tokens
=
num_padded
)
cross_per_layer
=
(
self
.
per_layer_inputs
[:
num_padded
]
if
self
.
per_layer_inputs
is
not
None
else
None
)
cross_hidden_states
=
self
.
cross_decoder
(
self
.
positions
[:
num_padded
],
self
.
hidden_states
[:
num_padded
],
cross_per_layer
,
**
kwargs
,
)
# Restore the original batch_descriptor
forward_context
.
batch_descriptor
=
orig_batch_desc
if
num_logits_indices
is
not
None
:
assert
num_logits_indices
>
0
hidden_states
[
logits_indices_padded
[:
num_logits_indices
]]
=
(
cross_hidden_states
[:
num_logits_indices
]
)
else
:
hidden_states
=
cross_hidden_states
return
hidden_states
def
forward
(
def
forward
(
self
,
self
,
...
@@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module):
...
@@ -859,6 +1169,18 @@ class Gemma4Model(nn.Module):
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
)
->
torch
.
Tensor
|
IntermediateTensors
:
if
self
.
fast_prefill_enabled
:
hidden_states
=
self
.
fast_prefill_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
# Normal (non-fast-prefill) path with PP support
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment