Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6fa41e0c
Unverified
Commit
6fa41e0c
authored
Aug 05, 2025
by
Yuxuan Zhang
Committed by
GitHub
Aug 04, 2025
Browse files
self.gate dtype update for GLM-4.5 (#22203)
Signed-off-by:
zRzRzRzRzRzRzR
<
2448370773@qq.com
>
parent
031ca762
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
4 additions
and
3 deletions
+4
-3
docs/models/supported_models.md
docs/models/supported_models.md
+1
-1
tests/models/registry.py
tests/models/registry.py
+1
-1
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+2
-1
No files found.
docs/models/supported_models.md
View file @
6fa41e0c
...
@@ -606,7 +606,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
...
@@ -606,7 +606,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
`GLM4VForCausalLM`
<sup>
^
</sup>
| GLM-4V | T + I |
`zai-org/glm-4v-9b`
,
`zai-org/cogagent-9b-20241220`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GLM4VForCausalLM`
<sup>
^
</sup>
| GLM-4V | T + I |
`zai-org/glm-4v-9b`
,
`zai-org/cogagent-9b-20241220`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4vForConditionalGeneration`
| GLM-4.1V-Thinking | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.1V-9B-Thinking`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4vForConditionalGeneration`
| GLM-4.1V-Thinking | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.1V-9B-Thinking`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4MoeForCausalLM`
| GLM-4.5 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.5`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4MoeForCausalLM`
| GLM-4.5 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.5`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4v_moeForConditionalGeneration`
| GLM-4.5V | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.5V
-Air
`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4v_moeForConditionalGeneration`
| GLM-4.5V | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.5V`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GraniteSpeechForConditionalGeneration`
| Granite Speech | T + A |
`ibm-granite/granite-speech-3.3-8b`
| ✅︎ | ✅︎ | ✅︎ |
|
`GraniteSpeechForConditionalGeneration`
| Granite Speech | T + A |
`ibm-granite/granite-speech-3.3-8b`
| ✅︎ | ✅︎ | ✅︎ |
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎ |
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎ |
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
, etc. | ✅︎ | | ✅︎ |
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
, etc. | ✅︎ | | ✅︎ |
...
...
tests/models/registry.py
View file @
6fa41e0c
...
@@ -383,7 +383,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -383,7 +383,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
,
trust_remote_code
=
True
,
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]}),
# noqa: E501
"Glm4vForConditionalGeneration"
:
_HfExamplesInfo
(
"zai-org/GLM-4.1V-9B-Thinking"
),
# noqa: E501
"Glm4vForConditionalGeneration"
:
_HfExamplesInfo
(
"zai-org/GLM-4.1V-9B-Thinking"
),
# noqa: E501
"Glm4v_moeForConditionalGeneration"
:
_HfExamplesInfo
(
"zai-org/GLM-4.5V
-Air
"
,
"Glm4v_moeForConditionalGeneration"
:
_HfExamplesInfo
(
"zai-org/GLM-4.5V"
,
is_available_online
=
False
),
# noqa: E501
is_available_online
=
False
),
# noqa: E501
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
,
"H2OVLChatModel"
:
_HfExamplesInfo
(
"h2oai/h2ovl-mississippi-800m"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
...
...
vllm/model_executor/models/glm4_moe.py
View file @
6fa41e0c
...
@@ -123,6 +123,7 @@ class Glm4MoE(nn.Module):
...
@@ -123,6 +123,7 @@ class Glm4MoE(nn.Module):
config
.
n_routed_experts
,
config
.
n_routed_experts
,
bias
=
False
,
bias
=
False
,
quant_config
=
None
,
quant_config
=
None
,
params_dtype
=
torch
.
float32
,
prefix
=
f
"
{
prefix
}
.gate"
)
prefix
=
f
"
{
prefix
}
.gate"
)
self
.
gate
.
e_score_correction_bias
=
nn
.
Parameter
(
self
.
gate
.
e_score_correction_bias
=
nn
.
Parameter
(
...
@@ -180,7 +181,7 @@ class Glm4MoE(nn.Module):
...
@@ -180,7 +181,7 @@ class Glm4MoE(nn.Module):
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
)
)
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment