Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2161efe9
Unverified
Commit
2161efe9
authored
Oct 06, 2025
by
Benjamin Chislett
Committed by
GitHub
Oct 06, 2025
Browse files
[Bugfix] Allow skipping MoE in NVFP4 (fix for MTP) (#25987)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
f23b4c04
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
5 deletions
+18
-5
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-0
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+4
-1
vllm/model_executor/models/deepseek_eagle.py
vllm/model_executor/models/deepseek_eagle.py
+1
-0
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+7
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+4
-2
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
2161efe9
...
...
@@ -1194,6 +1194,8 @@ class FusedMoE(CustomOp):
if
quant_config
is
None
else
quant_config
.
get_quant_method
(
self
,
prefix
)
)
if
quant_method
is
None
:
quant_method
=
UnquantizedFusedMoEMethod
(
moe
)
assert
quant_method
is
not
None
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
2161efe9
...
...
@@ -884,8 +884,9 @@ class ModelOptNvFp4Config(QuantizationConfig):
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
skip_layer
=
self
.
is_layer_excluded
(
prefix
)
if
isinstance
(
layer
,
LinearBase
):
if
s
elf
.
is_layer_excluded
(
prefix
)
:
if
s
kip_layer
:
return
UnquantizedLinearMethod
()
# Check if this is a vision model layer that should not be quantized
if
"vision_tower"
in
prefix
or
"vision_model"
in
prefix
:
...
...
@@ -894,6 +895,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
elif
isinstance
(
layer
,
Attention
):
return
ModelOptFp8KVCacheMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
if
skip_layer
:
return
None
return
ModelOptNvFp4FusedMoE
(
self
,
layer
.
moe_config
,
layer
)
return
None
...
...
vllm/model_executor/models/deepseek_eagle.py
View file @
2161efe9
...
...
@@ -55,6 +55,7 @@ class DeepseekV2Model(nn.Module):
DeepseekV2DecoderLayer
(
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
f
"layers.
{
i
+
start_layer_id
}
"
),
config
=
self
.
config
,
)
for
i
in
range
(
self
.
config
.
num_hidden_layers
)
]
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
2161efe9
...
...
@@ -48,7 +48,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
)
->
None
:
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
speculative_config
.
draft_model_config
.
hf_config
self
.
config
=
config
quant_config
=
vllm_config
.
quant_config
self
.
enorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -66,11 +67,15 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
)
else
:
topk_indices_buffer
=
None
self
.
shared_head
=
SharedHead
(
config
=
config
,
prefix
=
prefix
,
quant_config
=
quant_config
)
self
.
mtp_block
=
DeepseekV2DecoderLayer
(
vllm_config
,
prefix
,
topk_indices_buffer
vllm_config
,
prefix
,
config
=
self
.
config
,
topk_indices_buffer
=
topk_indices_buffer
,
)
def
forward
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
2161efe9
...
...
@@ -1055,11 +1055,13 @@ class DeepseekV2DecoderLayer(nn.Module):
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
,
config
:
Optional
[
DeepseekV2Config
]
=
None
,
topk_indices_buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
if
config
is
None
:
config
=
vllm_config
.
model_config
.
hf_config
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
...
...
@@ -1200,7 +1202,7 @@ class DeepseekV2Model(nn.Module):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
DeepseekV2DecoderLayer
(
vllm_config
,
prefix
,
topk_indices_buffer
vllm_config
,
prefix
,
topk_indices_buffer
=
topk_indices_buffer
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment