OpenDAS / vllm_cscc

Commit 92ff41ab (Unverified)
Authored Aug 14, 2025 by Jee Jee Li
Committed by GitHub, Aug 14, 2025

[Model] Modify the gate implementation of glm4_moe (#22832)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

Parent: 829b9a62

Showing 2 changed files with 11 additions and 11 deletions (+11 -11)

docs/models/supported_models.md (+1 -1)
vllm/model_executor/models/glm4_moe.py (+10 -10)

docs/models/supported_models.md
@@ -615,7 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
 | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | | ✅︎ | ✅︎ |
+| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
 | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
 | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |

vllm/model_executor/models/glm4_moe.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -118,14 +117,15 @@ class Glm4MoE(nn.Module):
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     params_dtype=torch.float32,
-                                     prefix=f"{prefix}.gate")
+        # NOTE In the transformers implementation, the gate isn't an nn.Linear,
+        # so we cannot use ReplicatedLinear here.
+        # See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260
+        self.gate = nn.Linear(
+            config.hidden_size,
+            config.n_routed_experts,
+            bias=False,
+            dtype=torch.float32,
+        )
         self.gate.e_score_correction_bias = nn.Parameter(
             torch.empty(config.n_routed_experts, dtype=torch.float32))
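
The new side of this hunk can be tried outside vLLM with plain PyTorch. The following is a minimal sketch, not the model code itself: the sizes are hypothetical, and `torch.zeros` stands in for the diff's `torch.empty` so the values are deterministic. It shows that assigning an `nn.Parameter` as an attribute registers it on the gate module next to `weight`.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
hidden_size, n_routed_experts = 16, 4

# Plain fp32 linear gate, mirroring the new side of the hunk.
gate = nn.Linear(hidden_size, n_routed_experts, bias=False,
                 dtype=torch.float32)

# Per-expert correction bias attached to the gate module. Assigning an
# nn.Parameter as an attribute registers it as a parameter of `gate`.
# torch.zeros replaces the diff's torch.empty to keep the sketch deterministic.
gate.e_score_correction_bias = nn.Parameter(
    torch.zeros(n_routed_experts, dtype=torch.float32))

param_names = sorted(name for name, _ in gate.named_parameters())
print(param_names)
print(gate.weight.dtype, tuple(gate.weight.shape))
```

Because both tensors live on the same module, they should appear under the same prefix in the module's `state_dict()`, whichever linear class holds them.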
@@ -181,7 +181,7 @@ class Glm4MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+        router_logits = self.gate(hidden_states.to(dtype=torch.float32))
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits) * self.routed_scaling_factor
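
The call-site change follows from the return type: the old line unpacked two values from the gate, whereas `nn.Linear` returns a single tensor. A small sketch under assumed shapes (3 tokens, hidden size 16, 4 experts, bf16 activations, all hypothetical) shows the cast to fp32 before routing:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical shapes: 3 tokens, hidden size 16, 4 routed experts.
gate = nn.Linear(16, 4, bias=False, dtype=torch.float32)
hidden_states = torch.randn(3, 16, dtype=torch.bfloat16)

# nn.Linear returns one tensor, so there is nothing to unpack. The input is
# cast to fp32 first so that it matches the dtype of the gate weight.
router_logits = gate(hidden_states.to(dtype=torch.float32))

print(router_logits.dtype, tuple(router_logits.shape))
```

Only the gate input is cast here; in the diff, the experts still receive the original `hidden_states`.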