Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
77d90699
Unverified
Commit
77d90699
authored
Sep 23, 2025
by
Yong Hoon Shin
Committed by
GitHub
Sep 23, 2025
Browse files
[KV sharing] Re-land Gemma3n model changes from #22628 (#24357)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
359d2930
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
344 additions
and
58 deletions
+344
-58
vllm/model_executor/models/gemma3n.py
vllm/model_executor/models/gemma3n.py
+344
-58
No files found.
vllm/model_executor/models/gemma3n.py
View file @
77d90699
...
...
@@ -26,6 +26,7 @@ from vllm.attention import Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
(
_ACTIVATION_REGISTRY
,
GeluAndMul
,
...
...
@@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
SupportsQuant
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
...
...
@@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
logger
=
init_logger
(
__name__
)
EPS
=
torch
.
tensor
(
torch
.
finfo
().
min
)
class
Gemma3nAltUp
(
nn
.
Module
):
"""Alternating updates (Altup)
...
...
@@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
return
corrected_predictions
@
support_torch_compile
class
Gemma3nTextModel
(
nn
.
Module
,
SupportsQuant
):
# This enables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nSelfDecoder
(
nn
.
Module
):
"""
Includes altup embedding and self decoder layers
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma3nDecoderLayer
],
layer_idx_start
:
int
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
...
...
@@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
prefix
=
f
"
{
prefix
}
.altup_projections.
{
idx
-
1
}
"
,
)
for
idx
in
range
(
1
,
self
.
config
.
altup_num_inputs
)
])
self
.
altup_unembed_projections
=
nn
.
ModuleList
([
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
,
gather_output
=
True
,
return_bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.altup_unembed_projections.
{
idx
-
1
}
"
,
)
for
idx
in
range
(
1
,
self
.
config
.
altup_num_inputs
)
])
# Transformer blocks.
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Gemma3nDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
,
)
self
.
eps
=
torch
.
tensor
(
torch
.
finfo
().
min
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
embed_scale
def
get_per_layer_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
return
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
*
self
.
embed_scale_per_layer
def
forward
(
def
get_per_layer_inputs
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
inputs_embeds
is
not
None
:
hidden_states_0
=
inputs_embeds
else
:
hidden_states_0
=
self
.
get_input_embeddings
(
input_ids
)
hidden_states_0
:
torch
.
Tensor
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
per_layer_projection
=
self
.
per_layer_model_projection
(
hidden_states_0
)
per_layer_projection
=
per_layer_projection
.
reshape
(
*
hidden_states_0
.
shape
[:
-
1
],
...
...
@@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
)
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
per_layer_inputs
is
not
None
:
# Profiling run does not compute per_layer_inputs
per_layer_inputs
=
per_layer_projection
+
per_layer_inputs
per_layer_inputs
*=
self
.
per_layer_input_scale
else
:
per_layer_inputs
=
per_layer_projection
return
per_layer_inputs
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
embed_scale
def
altup_embed
(
self
,
hidden_states_0
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Altup embed.
hidden_states
=
[
hidden_states_0
]
*
self
.
config
.
altup_num_inputs
target_magnitude
=
torch
.
mean
(
hidden_states_0
**
2
,
dim
=-
1
,
...
...
@@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=
0
)
new_magnitude
,
EPS
)
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=-
1
)
return
hidden_states
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
inputs_embeds
is
not
None
:
hidden_states_0
=
inputs_embeds
else
:
hidden_states_0
=
self
.
get_input_embeddings
(
input_ids
)
adjusted_per_layer_inputs
=
self
.
get_per_layer_inputs
(
hidden_states_0
,
per_layer_inputs
)
hidden_states
=
self
.
altup_embed
(
hidden_states_0
)
# [altnum_inputs, num_tokens, hidden_size]
hidden_states
=
hidden_states
.
permute
(
2
,
0
,
1
)
for
idx
,
layer
in
enumerate
(
self
.
decoder_layers
):
layer_idx
=
idx
+
self
.
layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_input
=
adjusted_per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
# [num_tokens, hidden_size, altnum_inputs]
hidden_states
=
hidden_states
.
permute
(
1
,
2
,
0
)
return
hidden_states
,
adjusted_per_layer_inputs
# This enables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nCrossDecoder
(
nn
.
Module
):
"""
Cross-decoder layers
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma3nDecoderLayer
],
layer_idx_start
:
int
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
# Transformer blocks.
for
layer_idx
,
layer
in
enumerate
(
self
.
layers
):
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
,
**
kwargs
,
)
->
torch
.
Tensor
:
# [altnum_inputs, num_tokens, hidden_size]
hidden_states
=
hidden_states
.
permute
(
2
,
0
,
1
)
for
idx
,
layer
in
enumerate
(
self
.
decoder_layers
):
layer_idx
=
idx
+
self
.
layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
...
...
@@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
per_layer_input
=
per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
# [num_tokens, hidden_size, altnum_inputs]
hidden_states
=
hidden_states
.
permute
(
1
,
2
,
0
)
return
hidden_states
# This disables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nTextModel
(
nn
.
Module
,
SupportsQuant
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
altup_unembed_projections
=
nn
.
ModuleList
([
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
,
gather_output
=
True
,
return_bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.altup_unembed_projections.
{
idx
-
1
}
"
,
)
for
idx
in
range
(
1
,
self
.
config
.
altup_num_inputs
)
])
# Allocate config.num_kv_shared_layers layers for self-decoder
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Gemma3nDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
first_kv_shared_layer_idx
=
(
config
.
num_hidden_layers
-
config
.
num_kv_shared_layers
)
# NOTE(sarckk): importing this top level seems to cause issues
# during running of tests.
from
vllm.compilation.backends
import
set_model_tag
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
with
set_model_tag
(
"self_decoder"
):
self
.
self_decoder
=
Gemma3nSelfDecoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.self_decoder"
,
decoder_layers
=
self
.
layers
[:
first_kv_shared_layer_idx
],
layer_idx_start
=
0
,
)
# Layer idx 20-30 are cross-decoder layers in YOCO
with
set_model_tag
(
"cross_decoder"
):
self
.
cross_decoder
=
Gemma3nCrossDecoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.cross_decoder"
,
decoder_layers
=
self
.
layers
[
first_kv_shared_layer_idx
:],
layer_idx_start
=
first_kv_shared_layer_idx
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
,
)
self
.
fast_prefill_enabled
=
cache_config
.
kv_sharing_fast_prefill
if
self
.
fast_prefill_enabled
:
# Allocate static buffers for CUDAGraph
# TODO(sarckk): Extract this functionality to interface
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
device
=
next
(
self
.
parameters
()).
device
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hidden_states
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
hidden_size
,
self
.
config
.
altup_num_inputs
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
self
.
per_layer_inputs
=
torch
.
zeros
(
(
max_num_tokens
,
self
.
config
.
num_hidden_layers
,
self
.
config
.
hidden_size_per_layer_input
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
@
property
def
embed_tokens
(
self
):
return
self
.
self_decoder
.
embed_tokens
def
get_per_layer_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
self_decoder
.
get_per_layer_input_embeddings
(
input_ids
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
self_decoder
.
get_input_embeddings
(
input_ids
)
def
fast_prefill_forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
logits_indices_padded
,
num_logits_indices
=
None
,
None
attn_metadata
=
get_forward_context
().
attn_metadata
# attn_metadata is None during dummy runs
if
(
self
.
fast_prefill_enabled
and
attn_metadata
is
not
None
):
assert
isinstance
(
attn_metadata
,
dict
)
# Last layer is a KV sharing layer
layer_attn_metadata
=
attn_metadata
[
self
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
]
if
(
isinstance
(
layer_attn_metadata
,
KVSharingFastPrefillMetadata
)):
logits_indices_padded
=
(
layer_attn_metadata
.
logits_indices_padded
)
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
# Copy inputs for cudagraph
batch_size
=
positions
.
size
(
0
)
self
.
positions
[:
batch_size
].
copy_
(
positions
)
self_decoder_hidden_states
,
per_layer_inputs_adjusted
=
\
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
self
.
positions
[:
batch_size
],
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
if
logits_indices_padded
is
None
:
logits_indices_padded
=
torch
.
arange
(
positions
.
size
(
0
),
dtype
=
positions
.
dtype
,
device
=
positions
.
device
,
)
# NOTE(sarckk): There is currently a bug caused by
# vLLM converting output of last piecewise CUDA graph
# to weakref, causing memory to be prematurely freed
# when there are multiple compilation units
# Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states
=
self_decoder_hidden_states
.
clone
()
# Copy inputs for cudagraph
num_padded_logits_indices
=
logits_indices_padded
.
size
(
0
)
self
.
positions
[:
num_padded_logits_indices
].
copy_
(
positions
[
logits_indices_padded
])
self
.
hidden_states
[:
num_padded_logits_indices
].
copy_
(
self_decoder_hidden_states
[
logits_indices_padded
])
self
.
per_layer_inputs
[:
num_padded_logits_indices
].
copy_
(
per_layer_inputs_adjusted
[
logits_indices_padded
])
cross_decoder_hidden_states
=
self
.
cross_decoder
(
positions
=
self
.
positions
[:
num_padded_logits_indices
],
hidden_states
=
self
.
hidden_states
[:
num_padded_logits_indices
],
per_layer_inputs
=
self
.
per_layer_inputs
[:
num_padded_logits_indices
],
**
kwargs
,
)
if
num_logits_indices
is
not
None
:
assert
num_logits_indices
>
0
# Merge cross-decoder and self-decoder hidden states
hidden_states
[
logits_indices_padded
[:
num_logits_indices
]]
=
(
cross_decoder_hidden_states
[:
num_logits_indices
])
else
:
hidden_states
=
cross_decoder_hidden_states
return
hidden_states
def
normal_forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
hidden_states
,
per_layer_inputs
=
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
cross_decoder
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
return
hidden_states
def
altup_unembed
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Altup unembed.
target_magnitude
=
torch
.
mean
(
hidden_states
[
0
]
**
2
,
target_magnitude
=
torch
.
mean
(
hidden_states
[
...,
0
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
for
i
in
range
(
1
,
self
.
config
.
altup_num_inputs
):
hidden_states
[
i
]
=
self
.
altup_unembed_projections
[
i
-
1
](
hidden_states
[
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
i
]
**
2
,
hidden_states
[
...,
i
]
=
self
.
altup_unembed_projections
[
i
-
1
](
hidden_states
[
...,
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
...,
i
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
hidden_states
=
torch
.
mean
(
hidden_states
,
dim
=
0
)
hidden_states
[...,
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
EPS
)
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
hidden_states
=
torch
.
mean
(
hidden_states
,
dim
=-
1
)
return
hidden_states
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
self
.
fast_prefill_enabled
:
hidden_states
=
self
.
fast_prefill_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
else
:
hidden_states
=
self
.
normal_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
altup_unembed
(
hidden_states
)
return
self
.
norm
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
@@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
# decoder layer weights, altup_unembed_projections and rmsnorm
# are initialized in text model, others are in self decoder
if
(
not
name
.
startswith
(
'layers'
)
and
not
name
.
startswith
(
'altup_unembed_projections'
)
and
not
name
.
startswith
(
'norm'
)):
name
=
f
"self_decoder.
{
name
}
"
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache scales for compressed-tensors quantization
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment