Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c3e1999
Unverified
Commit
8c3e1999
authored
Aug 29, 2025
by
Yong Hoon Shin
Committed by
GitHub
Aug 29, 2025
Browse files
Revert gemma3n fast prefill changes (#23897)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
1c26b422
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
355 deletions
+67
-355
tests/v1/e2e/test_kv_sharing_fast_prefill.py
tests/v1/e2e/test_kv_sharing_fast_prefill.py
+1
-0
vllm/model_executor/models/gemma3n.py
vllm/model_executor/models/gemma3n.py
+65
-354
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+1
-1
No files found.
tests/v1/e2e/test_kv_sharing_fast_prefill.py
View file @
8c3e1999
...
...
@@ -64,6 +64,7 @@ def cleanup(llm: LLM, compilation_config: CompilationConfig):
@
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
skip
(
reason
=
"Disable until Gemma3n supports fast prefill"
)
def
test_kv_sharing_fast_prefill
(
monkeypatch
:
pytest
.
MonkeyPatch
,
enforce_eager
:
bool
,
...
...
vllm/model_executor/models/gemma3n.py
View file @
8c3e1999
...
...
@@ -23,11 +23,9 @@ from torch import nn
from
transformers.models.gemma3n.configuration_gemma3n
import
Gemma3nTextConfig
from
vllm.attention
import
Attention
from
vllm.compilation.backends
import
set_model_tag
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
(
_ACTIVATION_REGISTRY
,
GeluAndMul
,
...
...
@@ -47,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
SupportsQuant
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
...
...
@@ -536,178 +533,7 @@ class Gemma3nDecoderLayer(nn.Module):
return
corrected_predictions
# This enables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nSelfDecoder
(
nn
.
Module
):
"""
Includes altup embedding and self decoder layers
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma3nDecoderLayer
],
layer_idx_start
:
int
,
per_layer_model_projection
:
ColumnParallelLinear
,
embed_scale_per_layer
:
torch
.
Tensor
,
embed_tokens_per_layer
:
VocabParallelEmbedding
,
per_layer_projection_norm
:
RMSNorm
,
per_layer_input_scale
:
torch
.
Tensor
,
altup_projections
:
nn
.
ModuleList
,
eps
:
torch
.
Tensor
,
embed_tokens
:
VocabParallelEmbedding
,
embed_scale
:
torch
.
Tensor
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
self
.
per_layer_model_projection
=
per_layer_model_projection
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
embed_scale_per_layer
=
embed_scale_per_layer
self
.
embed_tokens_per_layer
=
embed_tokens_per_layer
self
.
per_layer_projection_norm
=
per_layer_projection_norm
self
.
per_layer_input_scale
=
per_layer_input_scale
self
.
altup_projections
=
altup_projections
self
.
eps
=
eps
self
.
embed_tokens
=
embed_tokens
self
.
embed_scale
=
embed_scale
def
get_per_layer_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Deal with the fact that vocab_size_per_layer_input < vocab_size
# which causes us to have some out of vocab tokens by setting
# those token ids to 0. This matches the HF implementation.
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
config
.
vocab_size_per_layer_input
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
))
return
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
*
self
.
embed_scale_per_layer
def
get_per_layer_inputs
(
self
,
hidden_states_0
:
torch
.
Tensor
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
per_layer_projection
=
self
.
per_layer_model_projection
(
hidden_states_0
)
per_layer_projection
=
per_layer_projection
.
reshape
(
*
hidden_states_0
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
config
.
hidden_size_per_layer_input
,
)
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
per_layer_inputs
is
not
None
:
# Profiling run does not compute per_layer_inputs
per_layer_inputs
=
per_layer_projection
+
per_layer_inputs
per_layer_inputs
*=
self
.
per_layer_input_scale
else
:
per_layer_inputs
=
per_layer_projection
return
per_layer_inputs
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
embed_scale
def
altup_embed
(
self
,
hidden_states_0
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Altup embed.
hidden_states
=
[
hidden_states_0
]
*
self
.
config
.
altup_num_inputs
target_magnitude
=
torch
.
mean
(
hidden_states_0
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
for
i
in
range
(
1
,
self
.
config
.
altup_num_inputs
):
hidden_states
[
i
]
=
self
.
altup_projections
[
i
-
1
](
hidden_states
[
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
i
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=-
1
)
return
hidden_states
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
inputs_embeds
is
not
None
:
hidden_states_0
=
inputs_embeds
else
:
hidden_states_0
=
self
.
get_input_embeddings
(
input_ids
)
adjusted_per_layer_inputs
=
self
.
get_per_layer_inputs
(
hidden_states_0
,
per_layer_inputs
)
hidden_states
=
self
.
altup_embed
(
hidden_states_0
)
# [altnum_inputs, num_tokens, hidden_size]
hidden_states
=
hidden_states
.
permute
(
2
,
0
,
1
)
for
idx
,
layer
in
enumerate
(
self
.
decoder_layers
):
layer_idx
=
idx
+
self
.
layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_input
=
adjusted_per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
# [num_tokens, hidden_size, altnum_inputs]
hidden_states
=
hidden_states
.
permute
(
1
,
2
,
0
)
return
hidden_states
,
adjusted_per_layer_inputs
# This enables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nCrossDecoder
(
nn
.
Module
):
"""
Cross-decoder layers
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma3nDecoderLayer
],
layer_idx_start
:
int
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
,
**
kwargs
,
)
->
torch
.
Tensor
:
# [altnum_inputs, num_tokens, hidden_size]
hidden_states
=
hidden_states
.
permute
(
2
,
0
,
1
)
for
idx
,
layer
in
enumerate
(
self
.
decoder_layers
):
layer_idx
=
idx
+
self
.
layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_input
=
per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
# [num_tokens, hidden_size, altnum_inputs]
hidden_states
=
hidden_states
.
permute
(
1
,
2
,
0
)
return
hidden_states
# This disables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
@
support_torch_compile
class
Gemma3nTextModel
(
nn
.
Module
,
SupportsQuant
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -717,6 +543,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
...
...
@@ -786,211 +613,95 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
lambda
prefix
:
Gemma3nDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
eps
=
torch
.
tensor
(
torch
.
finfo
().
min
)
first_kv_shared_layer_idx
=
(
config
.
num_hidden_layers
-
config
.
num_kv_shared_layers
)
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
with
set_model_tag
(
"self_decoder"
):
self
.
self_decoder
=
Gemma3nSelfDecoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.self_decoder"
,
decoder_layers
=
self
.
layers
[:
first_kv_shared_layer_idx
],
layer_idx_start
=
0
,
per_layer_model_projection
=
self
.
per_layer_model_projection
,
embed_scale_per_layer
=
self
.
embed_scale_per_layer
,
embed_tokens_per_layer
=
self
.
embed_tokens_per_layer
,
per_layer_projection_norm
=
self
.
per_layer_projection_norm
,
per_layer_input_scale
=
self
.
per_layer_input_scale
,
altup_projections
=
self
.
altup_projections
,
eps
=
self
.
eps
,
embed_tokens
=
self
.
embed_tokens
,
embed_scale
=
self
.
embed_scale
,
)
# Layer idx 20-30 are cross-decoder layers in YOCO
with
set_model_tag
(
"cross_decoder"
):
self
.
cross_decoder
=
Gemma3nCrossDecoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.cross_decoder"
,
decoder_layers
=
self
.
layers
[
first_kv_shared_layer_idx
:],
layer_idx_start
=
first_kv_shared_layer_idx
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
,
)
self
.
fast_prefill_enabled
=
cache_config
.
kv_sharing_fast_prefill
if
self
.
fast_prefill_enabled
:
# Allocate static buffers for CUDAGraph
# TODO(sarckk): Extract this functionality to interface
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
device
=
next
(
self
.
parameters
()).
device
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hidden_states
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
hidden_size
,
self
.
config
.
altup_num_inputs
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
self
.
per_layer_inputs
=
torch
.
zeros
(
(
max_num_tokens
,
self
.
config
.
num_hidden_layers
,
self
.
config
.
hidden_size_per_layer_input
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
self
.
eps
=
torch
.
tensor
(
torch
.
finfo
().
min
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
self_decoder
.
get_input_embeddings
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
*
self
.
embed_scale
def
get_per_layer_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Deal with the fact that vocab_size_per_layer_input < vocab_size
# which causes us to have some out of vocab tokens by setting
# those token ids to 0. This matches the HF implementation.
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
config
.
vocab_size_per_layer_input
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
))
return
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
*
self
.
embed_scale_per_layer
def
fast_prefill_
forward
(
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
Optional
[
torch
.
Tensor
]
,
positions
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
logits_indices_padded
,
num_logits_indices
=
None
,
None
attn_metadata
=
get_forward_context
().
attn_metadata
# attn_metadata is None during dummy runs
if
(
self
.
fast_prefill_enabled
and
attn_metadata
is
not
None
):
assert
isinstance
(
attn_metadata
,
dict
)
# Last layer is a KV sharing layer
layer_attn_metadata
=
attn_metadata
[
self
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
]
if
(
isinstance
(
layer_attn_metadata
,
KVSharingFastPrefillMetadata
)):
logits_indices_padded
=
(
layer_attn_metadata
.
logits_indices_padded
)
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
# Copy inputs for cudagraph
batch_size
=
positions
.
size
(
0
)
self
.
positions
[:
batch_size
].
copy_
(
positions
)
self_decoder_hidden_states
,
per_layer_inputs_adjusted
=
\
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
self
.
positions
[:
batch_size
],
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
if
logits_indices_padded
is
None
:
logits_indices_padded
=
torch
.
arange
(
positions
.
size
(
0
),
dtype
=
positions
.
dtype
,
device
=
positions
.
device
,
)
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
inputs_embeds
is
not
None
:
hidden_states_0
=
inputs_embeds
else
:
hidden_states_0
=
self
.
get_input_embeddings
(
input_ids
)
# NOTE(sarckk): There is currently a bug caused by
# vLLM converting output of last piecewise CUDA graph
# to weakref, causing memory to be prematurely freed
# when there are multiple compilation units
# Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states
=
self_decoder_hidden_states
.
clone
()
# Copy inputs for cudagraph
num_padded_logits_indices
=
logits_indices_padded
.
size
(
0
)
self
.
positions
[:
num_padded_logits_indices
].
copy_
(
positions
[
logits_indices_padded
])
self
.
hidden_states
[:
num_padded_logits_indices
].
copy_
(
self_decoder_hidden_states
[
logits_indices_padded
])
self
.
per_layer_inputs
[:
num_padded_logits_indices
].
copy_
(
per_layer_inputs_adjusted
[
logits_indices_padded
])
cross_decoder_hidden_states
=
self
.
cross_decoder
(
positions
=
self
.
positions
[:
num_padded_logits_indices
],
hidden_states
=
self
.
hidden_states
[:
num_padded_logits_indices
],
per_layer_inputs
=
self
.
per_layer_inputs
[:
num_padded_logits_indices
],
**
kwargs
,
per_layer_projection
=
self
.
per_layer_model_projection
(
hidden_states_0
)
per_layer_projection
=
per_layer_projection
.
reshape
(
*
hidden_states_0
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
config
.
hidden_size_per_layer_input
,
)
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
num_logits_indices
is
not
None
:
assert
num_logits_indices
>
0
# Merge cross-decoder and self-decoder hidden states
hidden_states
[
logits_indices_padded
[:
num_logits_indices
]]
=
(
cross_decoder_hidden_states
[:
num_logits_indices
])
if
per_layer_inputs
is
not
None
:
# Profiling run does not compute per_layer_inputs
per_layer_inputs
=
per_layer_projection
+
per_layer_inputs
per_layer_inputs
*=
self
.
per_layer_input_scale
else
:
hidden_states
=
cross_decoder_hidden_states
per_layer_inputs
=
per_layer_projection
return
hidden_states
# Altup embed.
hidden_states
=
[
hidden_states_0
]
*
self
.
config
.
altup_num_inputs
target_magnitude
=
torch
.
mean
(
hidden_states_0
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
for
i
in
range
(
1
,
self
.
config
.
altup_num_inputs
):
hidden_states
[
i
]
=
self
.
altup_projections
[
i
-
1
](
hidden_states
[
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
i
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=
0
)
def
normal_forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
hidden_states
,
per_layer_inputs
=
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
cross_decoder
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
return
hidden_states
# Transformer blocks.
for
layer_idx
,
layer
in
enumerate
(
self
.
layers
):
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_input
=
per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
def
altup_unembed
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Altup unembed.
target_magnitude
=
torch
.
mean
(
hidden_states
[
...,
0
]
**
2
,
target_magnitude
=
torch
.
mean
(
hidden_states
[
0
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
for
i
in
range
(
1
,
self
.
config
.
altup_num_inputs
):
hidden_states
[
...,
i
]
=
self
.
altup_unembed_projections
[
i
-
1
](
hidden_states
[
...,
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
...,
i
]
**
2
,
hidden_states
[
i
]
=
self
.
altup_unembed_projections
[
i
-
1
](
hidden_states
[
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
i
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
...,
i
]
*=
target_magnitude
/
torch
.
maximum
(
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
hidden_states
=
torch
.
mean
(
hidden_states
,
dim
=-
1
)
return
hidden_states
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
hidden_states
=
torch
.
mean
(
hidden_states
,
dim
=
0
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
self
.
fast_prefill_enabled
:
hidden_states
=
self
.
fast_prefill_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
else
:
hidden_states
=
self
.
normal_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
altup_unembed
(
hidden_states
)
return
self
.
norm
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/gemma3n_mm.py
View file @
8c3e1999
...
...
@@ -620,7 +620,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
if
input_ids
is
not
None
:
per_layer_inputs
=
self
.
language_model
.
model
.
self_decoder
.
get_per_layer_input_embeddings
(
per_layer_inputs
=
self
.
language_model
.
model
.
get_per_layer_input_embeddings
(
input_ids
)
per_layer_inputs
=
per_layer_inputs
.
reshape
(
-
1
,
self
.
config
.
text_config
.
num_hidden_layers
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment