Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
93f2c0aa
Unverified
Commit
93f2c0aa
authored
Oct 08, 2025
by
Lukas Geiger
Committed by
GitHub
Oct 08, 2025
Browse files
[Models] Improve iteration over layers (#26425)
Signed-off-by:
Lukas Geiger
<
lukas.geiger94@gmail.com
>
parent
4ebc9108
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
23 additions
and
22 deletions
+23
-22
vllm/model_executor/models/apertus.py
vllm/model_executor/models/apertus.py
+4
-1
vllm/model_executor/models/falcon_h1.py
vllm/model_executor/models/falcon_h1.py
+2
-2
vllm/model_executor/models/hunyuan_v1.py
vllm/model_executor/models/hunyuan_v1.py
+5
-6
vllm/model_executor/models/lfm2_moe.py
vllm/model_executor/models/lfm2_moe.py
+2
-1
vllm/model_executor/models/longcat_flash.py
vllm/model_executor/models/longcat_flash.py
+2
-2
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+2
-2
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+3
-4
vllm/model_executor/models/qwen3_vl_moe.py
vllm/model_executor/models/qwen3_vl_moe.py
+3
-4
No files found.
vllm/model_executor/models/apertus.py
View file @
93f2c0aa
...
...
@@ -26,6 +26,7 @@
"""Inference-only Apertus model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
from
itertools
import
islice
from
typing
import
Any
,
Optional
,
Union
import
torch
...
...
@@ -412,7 +413,9 @@ class ApertusModel(nn.Module):
residual
=
intermediate_tensors
[
"residual"
]
aux_hidden_states
=
[]
for
idx
,
layer
in
enumerate
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]):
for
idx
,
layer
in
enumerate
(
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
)
):
if
idx
in
self
.
aux_hidden_state_layers
:
aux_hidden_states
.
append
(
hidden_states
+
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
...
...
vllm/model_executor/models/falcon_h1.py
View file @
93f2c0aa
...
...
@@ -3,6 +3,7 @@
"""Inference-only FalconH1 model."""
from
collections.abc
import
Iterable
from
itertools
import
islice
from
typing
import
Optional
import
torch
...
...
@@ -480,8 +481,7 @@ class FalconH1Model(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
...
...
vllm/model_executor/models/hunyuan_v1.py
View file @
93f2c0aa
...
...
@@ -26,6 +26,7 @@
import
typing
from
collections.abc
import
Callable
,
Iterable
from
itertools
import
islice
from
typing
import
Any
,
Optional
,
Union
import
regex
as
re
...
...
@@ -672,8 +673,9 @@ class HunYuanModel(nn.Module):
cla_factor
=
_get_cla_factor
(
self
.
config
)
prev_kv_states
=
None
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
i
,
layer
in
enumerate
(
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
)
):
hidden_states
,
residual
,
kv_states
=
layer
(
positions
,
hidden_states
,
...
...
@@ -681,10 +683,7 @@ class HunYuanModel(nn.Module):
prev_kv_states
,
)
if
(
getattr
(
self
.
config
,
"use_cla"
,
False
)
and
(
i
-
self
.
start_layer
)
%
cla_factor
==
0
):
if
getattr
(
self
.
config
,
"use_cla"
,
False
)
and
i
%
cla_factor
==
0
:
prev_kv_states
=
kv_states
else
:
prev_kv_states
=
None
...
...
vllm/model_executor/models/lfm2_moe.py
View file @
93f2c0aa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
itertools
import
islice
from
typing
import
Any
,
Optional
import
torch
...
...
@@ -492,7 +493,7 @@ class Lfm2MoeModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
:
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
)
:
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
...
...
vllm/model_executor/models/longcat_flash.py
View file @
93f2c0aa
...
...
@@ -35,6 +35,7 @@
import
typing
from
collections.abc
import
Callable
,
Iterable
from
itertools
import
islice
from
typing
import
Optional
,
Union
import
torch
...
...
@@ -519,8 +520,7 @@ class FlashModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
...
...
vllm/model_executor/models/mamba.py
View file @
93f2c0aa
...
...
@@ -3,6 +3,7 @@
"""PyTorch MAMBA model."""
from
collections.abc
import
Iterable
from
itertools
import
islice
from
typing
import
Optional
import
torch
...
...
@@ -162,8 +163,7 @@ class MambaModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
)
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
93f2c0aa
...
...
@@ -26,6 +26,7 @@
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
itertools
import
islice
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
numpy
as
np
...
...
@@ -1106,11 +1107,9 @@ class Qwen3LLMModel(Qwen3Model):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer_idx
,
layer
in
enumerat
e
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
for
layer_idx
,
layer
in
islic
e
(
enumerate
(
self
.
layers
),
self
.
start_layer
,
self
.
end_layer
):
layer_idx
=
layer_idx
+
self
.
start_layer
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
...
...
vllm/model_executor/models/qwen3_vl_moe.py
View file @
93f2c0aa
...
...
@@ -26,6 +26,7 @@
import
typing
from
collections.abc
import
Iterable
from
itertools
import
islice
from
typing
import
Callable
,
Optional
,
Union
import
torch
...
...
@@ -103,11 +104,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer_idx
,
layer
in
enumerat
e
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
for
layer_idx
,
layer
in
islic
e
(
enumerate
(
self
.
layers
),
self
.
start_layer
,
self
.
end_layer
):
layer_idx
=
layer_idx
+
self
.
start_layer
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment