sglang / Commits

Commit 8af84912 (unverified)
Authored by fzyzcjy on Oct 18, 2025; committed via GitHub on Oct 18, 2025
Support casting bf16 NextN moe to fp8 (#11613)
Parent: 505329ca

Showing 2 changed files with 93 additions and 3 deletions:

  python/sglang/srt/models/deepseek_nextn.py  (+17, -1)
  python/sglang/srt/models/deepseek_v2.py     (+76, -2)
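This change lets the NextN (multi-token prediction) draft layer run its MoE in fp8 even though its weights ship in bf16. When enable_nextn_moe_bf16_cast_to_fp8(quant_config) holds (modelopt_fp4 main-model quantization plus the DeepEP all-to-all backend), the NextN decoder layer is built with a DeepSeek-V3-style Fp8Config (128x128 weight blocks), load_weights quantizes the bf16 MoE weights to fp8 with UE8M0 scales via _quant_nextn_moe_to_fp8_ue8m0, and post_load_weights then transforms the resulting block scales in place via _transform_scale_nextn_moe_ue8m0.

For orientation, below is a minimal sketch of what a block-wise bf16-to-fp8 cast with power-of-two (UE8M0) scales looks like. The diff relies on sglang's quant_weight_ue8m0 helper for this; the standalone function below, its name, and its rounding and layout choices are illustrative assumptions, not the repository's implementation.

import torch

def cast_bf16_weight_to_block_fp8(weight: torch.Tensor, block=(128, 128)):
    """Quantize a 2-D bf16 weight to fp8 (e4m3) with one power-of-two scale per block (sketch)."""
    n, k = weight.shape
    bn, bk = block
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    qweight = torch.empty(n, k, dtype=torch.float8_e4m3fn, device=weight.device)
    scale_inv = torch.empty(
        (n + bn - 1) // bn, (k + bk - 1) // bk, dtype=torch.float32, device=weight.device
    )
    for i in range(0, n, bn):
        for j in range(0, k, bk):
            blk = weight[i : i + bn, j : j + bk].float()
            amax = blk.abs().amax().clamp(min=1e-12)
            # UE8M0 scales are exponent-only, so round the block scale up to a power of two
            scale = torch.exp2(torch.ceil(torch.log2(amax / fp8_max)))
            qweight[i : i + bn, j : j + bk] = (blk / scale).clamp(-fp8_max, fp8_max).to(
                torch.float8_e4m3fn
            )
            # weight_scale_inv follows the fp8 checkpoint convention: dequant = fp8 * scale
            scale_inv[i // bn, j // bk] = scale
    return qweight, scale_inv

# toy usage: a 256x512 bf16 weight yields fp8 data plus a (2, 4) grid of block scales
w_fp8, w_scale_inv = cast_bf16_weight_to_block_fp8(torch.randn(256, 512, dtype=torch.bfloat16))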
python/sglang/srt/models/deepseek_nextn.py
@@ -25,13 +25,18 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import Fp8Config
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
+from sglang.srt.models.deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    DeepseekV3ForCausalLM,
+    enable_nextn_moe_bf16_cast_to_fp8,
+)
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
@@ -49,6 +54,16 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+            # refer to real DeepSeek V3 quant config
+            moe_quant_config = Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            moe_quant_config = None
+
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
                 "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
@@ -74,6 +89,7 @@ class DeepseekModelNextN(nn.Module):
             config,
             0,
             quant_config=quant_config,
+            moe_quant_config=moe_quant_config,
             is_nextn=True,
             prefix=add_prefix("decoder", prefix),
             alt_stream=self.alt_stream,
python/sglang/srt/models/deepseek_v2.py
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm, trange
 from transformers import PretrainedConfig

 from sglang.srt import single_batch_overlap
@@ -82,7 +83,7 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization import Fp8Config, deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
@@ -196,6 +197,15 @@ _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()

 logger = logging.getLogger(__name__)

+
+def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+    return (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and get_moe_a2a_backend().is_deepep()
+    )
+
+
 FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
     "fa3",
     "nsa",
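In other words, the cast is only applied when the main model uses modelopt_fp4 quantization (whose checkpoints appear to leave the NextN MoE weights in bf16) and the DeepEP all-to-all MoE backend is selected; in every other configuration the NextN layer keeps its original quant config.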
@@ -526,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.is_nextn = is_nextn

         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -2381,6 +2392,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        moe_quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
@@ -2430,7 +2442,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
-                quant_config=quant_config,
+                quant_config=moe_quant_config or quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
@@ -3109,6 +3121,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)

+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            self._transform_scale_nextn_moe_ue8m0()
+
     def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
@@ -3174,6 +3189,28 @@ class DeepseekV2ForCausalLM(nn.Module):
                         module.weight, module.weight_scale_inv, weight_block_size
                     )

+    # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
+    def _transform_scale_nextn_moe_ue8m0(self):
+        layer = self.model.decoder
+
+        shared_experts = getattr(layer.mlp, "shared_experts", None)
+        if shared_experts is not None:
+            for module in [
+                shared_experts.gate_up_proj,
+                shared_experts.down_proj,
+            ]:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+        experts = layer.mlp.experts
+        if isinstance(experts, DeepEPMoE):
+            for w in [
+                experts.w13_weight_fp8,
+                experts.w2_weight_fp8,
+            ]:
+                transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
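For context, UE8M0 is an exponent-only number format: each scale stores just a biased 8-bit exponent, i.e. a power of two, with no mantissa. The tiny sketch below illustrates that format itself; it is not sglang's transform_scale_ue8m0_inplace (which rewrites the scale tensors in place into the layout the fp8 kernels expect), and the function names and the bias of 127 are assumptions for illustration only.

import torch

def float_scale_to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # keep only the power-of-two part of each scale (8 exponent bits, no mantissa)
    exponent = torch.ceil(torch.log2(scale))
    return (exponent + 127).clamp(0, 255).to(torch.uint8)

def ue8m0_to_float_scale(packed: torch.Tensor) -> torch.Tensor:
    # invert the packing: recover the power-of-two scale from the biased exponent
    return torch.exp2(packed.to(torch.float32) - 127.0)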
@@ -3189,6 +3226,11 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")

+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            weights = self._quant_nextn_moe_to_fp8_ue8m0(
+                weights, nextn_layer_id=nextn_layer_id
+            )
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -3418,6 +3460,38 @@ class DeepseekV2ForCausalLM(nn.Module):

         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

+    # TODO avoid code dup
+    def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in [nextn_layer_id]:
+            for expert_sub_name in [
+                "shared_experts",
+                *[
+                    f"experts.{expert_id}"
+                    for expert_id in range(self.config.n_routed_experts)
+                ],
+            ]:
+                for stem in [
+                    "gate_proj",
+                    "up_proj",
+                    "down_proj",
+                ]:
+                    partial_name = (
+                        f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
+                    )
+                    original_weight = weights_dict[f"{partial_name}.weight"]
+                    out_w, out_s = quant_weight_ue8m0(
+                        original_weight, weight_block_size=weight_block_size
+                    )
+                    weights_dict[f"{partial_name}.weight"] = out_w
+                    weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
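The net effect on the checkpoint stream is easiest to see from the parameter names. The runnable toy below only restates the naming logic from the hunk above; the layer id and expert count are assumptions for the printout, not values read from a real config.

# each matching bf16 ".weight" entry is replaced by its fp8 counterpart and gains a sibling
# ".weight_scale_inv" entry holding the per-block scales
nextn_layer_id = 61      # assumption: the NextN/MTP layer sits after the regular decoder layers
n_routed_experts = 4     # assumption: tiny expert count just for the example

for expert_sub_name in ["shared_experts", *[f"experts.{i}" for i in range(n_routed_experts)]]:
    for stem in ["gate_proj", "up_proj", "down_proj"]:
        partial_name = f"model.layers.{nextn_layer_id}.mlp.{expert_sub_name}.{stem}"
        print(f"{partial_name}.weight -> {partial_name}.weight + {partial_name}.weight_scale_inv")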