Support casting bf16 NextN moe to fp8 (#11613)

Commit 8af84912 (unverified) in zhaoyu6/sglang
Authored Oct 18, 2025 by fzyzcjy; committed by GitHub on Oct 18, 2025
Parent commit: 505329ca
Showing 2 changed files with 93 additions and 3 deletions.
python/sglang/srt/models/deepseek_nextn.py  (+17 −1)
python/sglang/srt/models/deepseek_v2.py     (+76 −2)
python/sglang/srt/models/deepseek_nextn.py (view file @ 8af84912)
@@ -25,13 +25,18 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import Fp8Config
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
+from sglang.srt.models.deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    DeepseekV3ForCausalLM,
+    enable_nextn_moe_bf16_cast_to_fp8,
+)
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
@@ -49,6 +54,16 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+            # refer to real DeepSeek V3 quant config
+            moe_quant_config = Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            moe_quant_config = None
+
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
                 "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
@@ -74,6 +89,7 @@ class DeepseekModelNextN(nn.Module):
             config,
             0,
             quant_config=quant_config,
+            moe_quant_config=moe_quant_config,
             is_nextn=True,
             prefix=add_prefix("decoder", prefix),
             alt_stream=self.alt_stream,
python/sglang/srt/models/deepseek_v2.py (view file @ 8af84912)
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm, trange
 from transformers import PretrainedConfig
 
 from sglang.srt import single_batch_overlap
@@ -82,7 +83,7 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization import Fp8Config, deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
@@ -196,6 +197,15 @@ _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
 
+
+def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+    return (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and get_moe_a2a_backend().is_deepep()
+    )
+
+
 FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
     "fa3",
     "nsa",
@@ -526,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.is_nextn = is_nextn
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -2381,6 +2392,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -2381,6 +2392,7 @@ class DeepseekV2DecoderLayer(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
layer_id
:
int
,
layer_id
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
moe_quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
is_nextn
:
bool
=
False
,
is_nextn
:
bool
=
False
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
alt_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
alt_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
@@ -2430,7 +2442,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
-                quant_config=quant_config,
+                quant_config=moe_quant_config or quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
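A note on the quant_config=moe_quant_config or quant_config change: when no MoE-specific override was built (every non-NextN layer, and NextN layers where the cast is disabled), the expression falls back to the layer-wide config, so existing behaviour is unchanged. Spelled out:

# Equivalent spelling of the fallback (illustration only): prefer the MoE
# override when it exists, otherwise keep the layer-wide quant config.
effective_quant_config = (
    moe_quant_config if moe_quant_config is not None else quant_config
)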
@@ -3109,6 +3121,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            self._transform_scale_nextn_moe_ue8m0()
+
     def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
@@ -3174,6 +3189,28 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
+    # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
+    def _transform_scale_nextn_moe_ue8m0(self):
+        layer = self.model.decoder
+
+        shared_experts = getattr(layer.mlp, "shared_experts", None)
+        if shared_experts is not None:
+            for module in [
+                shared_experts.gate_up_proj,
+                shared_experts.down_proj,
+            ]:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+        experts = layer.mlp.experts
+        if isinstance(experts, DeepEPMoE):
+            for w in [
+                experts.w13_weight_fp8,
+                experts.w2_weight_fp8,
+            ]:
+                transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
+
     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False
     ):
         if is_nextn:
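Background for the new _transform_scale_nextn_moe_ue8m0 hook: UE8M0 is an 8-bit, exponent-only encoding, so a UE8M0 scale can only represent powers of two. The transform_scale_ue8m0_inplace calls above presumably rewrite the per-block weight_scale_inv tensors of the NextN MoE into that power-of-two form, matching what the existing requant path already does for the main model. A rough, illustrative sketch of the rounding idea only; the actual sglang kernel also handles packing and tensor layout:

import torch

# Rough illustration only (assumption): snap per-block inverse scales up to
# the nearest power of two, the only values an exponent-only (UE8M0) scale
# encoding can represent.
def snap_scales_to_pow2(scale_inv: torch.Tensor) -> torch.Tensor:
    return torch.exp2(torch.ceil(torch.log2(scale_inv)))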
@@ -3189,6 +3226,11 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")
 
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            weights = self._quant_nextn_moe_to_fp8_ue8m0(
+                weights, nextn_layer_id=nextn_layer_id
+            )
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
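Taken together with the first file, the load path for the NextN (MTP draft) layer now has two extra steps when the cast is enabled. A hedged overview sketch (not code from the commit; the method names are the ones added in this diff, the function wrapper is hypothetical):

# Hedged overview sketch of the load-time flow added by this commit.
def load_nextn_weights_sketch(model, weights, nextn_layer_id):
    if enable_nextn_moe_bf16_cast_to_fp8(model.quant_config):
        # 1) While streaming weights in, the bf16 NextN MoE weights are
        #    quantized to block-FP8 (see _quant_nextn_moe_to_fp8_ue8m0 below).
        weights = model._quant_nextn_moe_to_fp8_ue8m0(
            weights, nextn_layer_id=nextn_layer_id
        )
    # ... regular weight loading via stacked_params_mapping happens here ...
    if enable_nextn_moe_bf16_cast_to_fp8(model.quant_config):
        # 2) After loading, the per-block scales are rewritten into UE8M0
        #    form (see _transform_scale_nextn_moe_ue8m0 above).
        model._transform_scale_nextn_moe_ue8m0()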
@@ -3418,6 +3460,38 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    # TODO avoid code dup
+    def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in [nextn_layer_id]:
+            for expert_sub_name in [
+                "shared_experts",
+                *[
+                    f"experts.{expert_id}"
+                    for expert_id in range(self.config.n_routed_experts)
+                ],
+            ]:
+                for stem in [
+                    "gate_proj",
+                    "up_proj",
+                    "down_proj",
+                ]:
+                    partial_name = (
+                        f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
+                    )
+                    original_weight = weights_dict[f"{partial_name}.weight"]
+                    out_w, out_s = quant_weight_ue8m0(
+                        original_weight, weight_block_size=weight_block_size
+                    )
+                    weights_dict[f"{partial_name}.weight"] = out_w
+                    weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
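For intuition about the conversion itself: quant_weight_ue8m0 (defined elsewhere in sglang, not shown in this diff) returns an FP8 weight tensor together with a per-block inverse scale, one scale per 128x128 block, matching the weight / weight_scale_inv pair that block-quantized DeepSeek V3 checkpoints ship. A deliberately simplified stand-in is sketched below; it performs plain e4m3 block quantization and, unlike the real helper, ignores the UE8M0 scale constraint, padding, and layout details:

import torch

# Simplified, illustrative block quantizer (assumption): one scale per
# (128, 128) block, weights cast to float8_e4m3fn. Assumes the weight's
# dimensions are multiples of the block size.
def blockwise_fp8_quant(w: torch.Tensor, block=(128, 128)):
    bm, bn = block
    m, n = w.shape
    blocks = w.reshape(m // bm, bm, n // bn, bn)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    scale = fp8_max / amax
    w_fp8 = (blocks * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # weight_scale_inv is the dequantization factor, one entry per block.
    scale_inv = (1.0 / scale).squeeze(1).squeeze(-1)
    return w_fp8.reshape(m, n), scale_inv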