Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1958bda9
Unverified
Commit
1958bda9
authored
Nov 07, 2025
by
Mengqing Cao
Committed by
GitHub
Nov 07, 2025
Browse files
[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)
Signed-off-by:
MengqingCao
<
cmq0113@163.com
>
parent
7bdb42b2
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
49 additions
and
7 deletions
+49
-7
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+19
-4
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+9
-2
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+2
-0
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+2
-0
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+2
-0
vllm/model_executor/models/zamba2.py
vllm/model_executor/models/zamba2.py
+15
-1
No files found.
vllm/model_executor/models/persimmon.py
View file @
1958bda9
...
...
@@ -62,14 +62,23 @@ from .utils import (
class
PersimmonMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PersimmonConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
self
,
config
:
PersimmonConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
dense_h_to_4h
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
quant_config
=
quant_config
config
.
hidden_size
,
config
.
intermediate_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense_h_to_4h"
,
)
self
.
dense_4h_to_h
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
quant_config
=
quant_config
config
.
intermediate_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense_4h_to_h"
,
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
)
...
...
@@ -110,12 +119,14 @@ class PersimmonAttention(nn.Module):
self
.
total_num_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.query_key_value"
,
)
self
.
dense
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense"
,
)
self
.
is_qk_layernorm
=
config
.
qk_layernorm
...
...
@@ -192,7 +203,11 @@ class PersimmonDecoderLayer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
PersimmonMLP
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
PersimmonMLP
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
...
...
vllm/model_executor/models/phi.py
View file @
1958bda9
...
...
@@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
self
.
total_num_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
dense
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense"
,
)
scaling
=
self
.
head_size
**-
0.5
...
...
@@ -148,7 +150,10 @@ class PhiAttention(nn.Module):
class
PhiMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PhiConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
self
,
config
:
PhiConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
...
...
@@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
config
.
hidden_size
,
n_inner
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
,
)
self
.
fc2
=
RowParallelLinear
(
n_inner
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
)
...
...
@@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
self
.
self_attn
=
PhiAttention
(
config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
mlp
=
PhiMLP
(
config
,
quant_config
)
self
.
mlp
=
PhiMLP
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
self
,
...
...
vllm/model_executor/models/phimoe.py
View file @
1958bda9
...
...
@@ -343,12 +343,14 @@ class PhiMoEAttention(nn.Module):
self
.
total_num_kv_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
vllm/model_executor/models/plamo2.py
View file @
1958bda9
...
...
@@ -567,12 +567,14 @@ class Plamo2AttentionMixer(nn.Module):
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
config
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rope_theta
=
config
.
rope_theta
if
hasattr
(
config
,
"rope_theta"
)
else
10000
...
...
vllm/model_executor/models/qwen.py
View file @
1958bda9
...
...
@@ -102,12 +102,14 @@ class QWenAttention(nn.Module):
self
.
total_num_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_attn"
,
)
self
.
c_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_proj"
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
...
...
vllm/model_executor/models/zamba2.py
View file @
1958bda9
...
...
@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
super
().
__init__
()
self
.
A
=
ColumnParallelLinear
(
input_dim
,
rank
,
bias
=
False
,
quant_config
=
quant_config
,
gather_output
=
True
input_dim
,
rank
,
bias
=
False
,
quant_config
=
quant_config
,
gather_output
=
True
,
prefix
=
f
"
{
prefix
}
.A"
,
)
if
isinstance
(
output_dim
,
list
):
...
...
@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
self
.
total_num_attention_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
attention_hidden_size
,
config
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
# Even though in Zamba2 weights are shared between attention layers, KV
...
...
@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
config
.
adapter_rank
,
self
.
attention_hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.linear_q_adapter"
,
)
linear_k_adapter
=
Zamba2LoRA
(
self
.
attention_hidden_size
,
config
.
adapter_rank
,
self
.
attention_hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.linear_k_adapter"
,
)
linear_v_adapter
=
Zamba2LoRA
(
self
.
attention_hidden_size
,
config
.
adapter_rank
,
self
.
attention_hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.linear_v_adapter"
,
)
else
:
linear_q_adapter
=
nn
.
Identity
()
...
...
@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
2
*
[
self
.
intermediate_size
],
# 2x for gate and input projections
bias
=
self
.
config
.
add_bias_linear
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
...
...
@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
self
.
hidden_size
,
bias
=
self
.
config
.
add_bias_linear
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
# Only allow GELU activations
...
...
@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
bare_block_idx
=
bare_block_idx
,
num_hybrid_layers
=
num_hybrid_layers
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.feed_forward"
,
)
# Initialize layer normalizations
...
...
@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
config
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.linear"
,
)
self
.
mamba_decoder
=
Zamba2MambaDecoderLayer
(
config
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment