Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ac3dac54
Unverified
Commit
ac3dac54
authored
Apr 15, 2026
by
Benjamin Chislett
Committed by
GitHub
Apr 15, 2026
Browse files
[Bugfix][Perf] Indexer upcast WK to BF16 for fusion (#38928)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
39ac6404
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
64 deletions
+84
-64
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+14
-11
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+70
-53
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
ac3dac54
...
...
@@ -30,6 +30,7 @@ from .deepseek_v2 import (
DeepseekV2DecoderLayer
,
DeepseekV2MixtureOfExperts
,
DeepseekV2MoE
,
_try_load_fp8_indexer_wk
,
get_spec_layer_idx_from_weight_name
,
)
from
.utils
import
maybe_prefix
...
...
@@ -190,10 +191,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
)
# Set MoE hyperparameters
self
.
set_moe_parameters
()
self
.
is_fp4_ckpt
=
(
self
.
quant_config
is
not
None
and
self
.
quant_config
.
get_name
()
==
"modelopt_fp4"
)
def
set_moe_parameters
(
self
):
self
.
expert_weights
=
[]
...
...
@@ -248,7 +245,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
(
"fused_qkv_a_proj"
,
"kv_a_proj_with_mqa"
,
1
),
]
if
self
.
is_fp4_ckpt
:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping
=
[
(
"wk_weights_proj"
,
"wk"
,
0
),
...
...
@@ -271,6 +267,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
_pending_wk_fp8
:
dict
=
{}
# FP8 indexer wk dequant buffer
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
...
...
@@ -281,6 +278,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
rocm_aiter_moe_shared_expert_enabled
and
(
"mlp.shared_experts"
in
name
)
)
name
=
self
.
_rewrite_spec_layer_name
(
spec_layer
,
name
)
if
_try_load_fp8_indexer_wk
(
name
,
loaded_weight
,
_pending_wk_fp8
,
params_dict
,
loaded_params
):
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
ac3dac54
...
...
@@ -66,6 +66,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
scaled_dequantize
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sparse_attn_indexer
import
(
SparseAttnIndexer
,
...
...
@@ -628,10 +632,6 @@ class Indexer(nn.Module):
self
.
vllm_config
=
vllm_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
is_fp4_ckpt
=
(
self
.
quant_config
is
not
None
and
self
.
quant_config
.
get_name
()
==
"modelopt_fp4"
)
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self
.
topk_tokens
=
config
.
index_topk
self
.
n_head
=
config
.
index_n_heads
# 64
...
...
@@ -646,13 +646,8 @@ class Indexer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.wq_b"
,
)
if
self
.
is_fp4_ckpt
:
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj does not get quantized,
# so we run both with quant_config=None
# wk may be upcasted from the default quant;
# experiments show fusion is always faster unless WK proj is in FP4,
# which is not the case for all known quants.
# FP8 wk weights are upcasted to BF16 during loading to maintain fusion.
self
.
wk_weights_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
self
.
head_dim
,
self
.
n_head
],
...
...
@@ -661,21 +656,6 @@ class Indexer(nn.Module):
disable_tp
=
True
,
prefix
=
f
"
{
prefix
}
.wk_weights_proj"
,
)
else
:
self
.
wk
=
ReplicatedLinear
(
hidden_size
,
self
.
head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.wk"
,
)
self
.
weights_proj
=
ReplicatedLinear
(
hidden_size
,
self
.
n_head
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.weights_proj"
,
)
self
.
k_norm
=
LayerNorm
(
self
.
head_dim
,
eps
=
1e-6
)
self
.
softmax_scale
=
self
.
head_dim
**-
0.5
...
...
@@ -716,14 +696,10 @@ class Indexer(nn.Module):
q_pe
,
q_nope
=
torch
.
split
(
q
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
)
if
self
.
is_fp4_ckpt
:
# Fused wk + weights_proj: one GEMM, then split
kw
,
_
=
self
.
wk_weights_proj
(
hidden_states
)
k
=
kw
[:,
:
self
.
head_dim
]
weights
=
kw
[:,
self
.
head_dim
:]
else
:
k
,
_
=
self
.
wk
(
hidden_states
)
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
k
=
self
.
k_norm
(
k
)
k_pe
,
k_nope
=
torch
.
split
(
...
...
@@ -761,6 +737,46 @@ class Indexer(nn.Module):
return
self
.
indexer_op
(
hidden_states
,
q_fp8
,
k
,
weights
)
def
_try_load_fp8_indexer_wk
(
name
,
tensor
,
buf
,
params_dict
,
loaded_params
):
"""
We fuse the WK and weights_proj projections, but in some checkpoints WK is stored
in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16.
Upcasting to BF16 during loading enables the fusion. This function loads the FP8 WK
weights and scale, and when both are available, dequantizes to BF16 and stores into
the fused wk_weights_proj.weight parameter.
"""
if
"indexer.wk."
not
in
name
or
"wk_weights"
in
name
:
return
False
# Weight is not an isolated WK weight for the indexer, ignore.
is_weight
=
name
.
endswith
(
".weight"
)
and
tensor
.
dtype
==
torch
.
float8_e4m3fn
is_scale
=
"weight_scale_inv"
in
name
if
not
is_weight
and
not
is_scale
:
return
False
# WK is not in FP8 format, ignore.
# Buffer this tensor (weight or scale) until both have arrived.
layer_prefix
=
name
.
rsplit
(
".wk."
,
1
)[
0
]
# e.g. "model.layers.0.self_attn.indexer"
entry
=
buf
.
setdefault
(
layer_prefix
,
{})
entry
[
"weight"
if
is_weight
else
"scale"
]
=
tensor
if
"weight"
not
in
entry
or
"scale"
not
in
entry
:
return
True
# still waiting for the other param
# We have both weight and scale: dequantize FP8 to BF16.
weight_fp8
,
scale_inv
=
entry
[
"weight"
],
entry
[
"scale"
]
del
buf
[
layer_prefix
]
block_size
=
weight_fp8
.
shape
[
1
]
//
scale_inv
.
shape
[
1
]
weight_bf16
=
scaled_dequantize
(
weight_fp8
,
scale_inv
,
group_shape
=
GroupShape
(
block_size
,
block_size
),
out_dtype
=
torch
.
bfloat16
,
)
# Load the dequantized weight into shard 0 of the fused buffer.
fused_name
=
f
"
{
layer_prefix
}
.wk_weights_proj.weight"
param
=
params_dict
[
fused_name
]
param
.
weight_loader
(
param
,
weight_bf16
,
0
)
loaded_params
.
add
(
fused_name
)
return
True
def
_min_latency_fused_qkv_a_proj_impl
(
input_
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
@@ -1344,10 +1360,6 @@ class DeepseekV2ForCausalLM(
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
is_fp4_ckpt
=
(
self
.
quant_config
is
not
None
and
self
.
quant_config
.
get_name
()
==
"modelopt_fp4"
)
qk_nope_head_dim
=
getattr
(
config
,
"qk_nope_head_dim"
,
0
)
qk_rope_head_dim
=
getattr
(
config
,
"qk_rope_head_dim"
,
0
)
...
...
@@ -1473,8 +1485,8 @@ class DeepseekV2ForCausalLM(
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
is_fp4_ckpt
:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
_pending_wk_fp8
:
dict
=
{}
# When WK is in FP8, we dequant to BF16 for fusion
indexer_fused_mapping
=
[
(
"wk_weights_proj"
,
"wk"
,
0
),
(
"wk_weights_proj"
,
"weights_proj"
,
1
),
...
...
@@ -1516,6 +1528,11 @@ class DeepseekV2ForCausalLM(
rocm_aiter_moe_shared_expert_enabled
and
(
"mlp.shared_experts"
in
name
)
)
if
_try_load_fp8_indexer_wk
(
name
,
loaded_weight
,
_pending_wk_fp8
,
params_dict
,
loaded_params
):
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment