Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a9082a4d
Unverified
Commit
a9082a4d
authored
Aug 25, 2025
by
Isotr0py
Committed by
GitHub
Aug 25, 2025
Browse files
[Bugfix] Fix Qwen3 MoE GPTQ inference (#23490)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
e0329ed4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
6 deletions
+18
-6
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+18
-6
No files found.
vllm/model_executor/models/qwen3_moe.py
View file @
a9082a4d
...
@@ -45,6 +45,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -45,6 +45,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
...
@@ -146,11 +149,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
...
@@ -146,11 +149,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb
=
self
.
enable_eplb
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
)
num_redundant_experts
=
self
.
n_redundant_experts
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
num_experts
,
config
.
hidden_size
,
bias
=
False
,
config
.
num_experts
,
quant_config
=
quant_config
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.gate"
)
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
),
prefix
=
f
"
{
prefix
}
.gate"
)
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
if
isinstance
(
quant_config
,
(
GPTQConfig
,
GPTQMarlinConfig
)):
return
None
return
quant_config
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
# NOTE: hidden_states can have either 1D or 2D shape.
...
@@ -682,4 +694,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
...
@@ -682,4 +694,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
return
self
.
model
.
get_expert_mapping
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment