change / sglang / Commits / c2c4f57f

Unverified commit c2c4f57f, authored Jun 07, 2025 by Pavani Majety, committed via GitHub on Jun 07, 2025

[DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (#6853)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>

Parent: 23881fa6
Showing 3 changed files with 386 additions and 13 deletions (+386, -13):

  python/sglang/srt/layers/moe/fused_moe_triton/layer.py   (+19, -1)
  python/sglang/srt/layers/quantization/modelopt_quant.py  (+334, -7)
  python/sglang/srt/models/deepseek_v2.py                  (+33, -5)
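For orientation (not part of the diff): once this lands, the FP4 checkpoint should be loadable through sglang's offline engine. A minimal sketch follows; the checkpoint id nvidia/DeepSeek-R1-FP4, the tensor-parallel degree, and the explicit quantization="modelopt_fp4" argument are illustrative assumptions (the quantization name itself comes from the deepseek_v2.py change below), not a verified launch recipe.

# Hypothetical usage sketch -- argument names follow sglang's server args, values are illustrative.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="nvidia/DeepSeek-R1-FP4",  # assumed Hugging Face checkpoint id
        quantization="modelopt_fp4",          # quant name checked in deepseek_v2.py below
        tp_size=8,                            # illustrative tensor-parallel degree
    )
    print(llm.generate("What is NVFP4?", {"max_new_tokens": 32}))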
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -556,7 +556,8 @@ class FusedMoE(torch.nn.Module):
             loaded_weight = loaded_weight.to(param.data.device)

             if (
-                param.data[expert_id] != 1
+                "compressed" in self.quant_method.__class__.__name__.lower()
+                and param.data[expert_id] != 1
                 and (param.data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -580,6 +581,23 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+        if "ModelOpt" in self.quant_method.__class__.__name__:
+            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+                self._load_per_tensor_weight_scale(
+                    shard_id=shard_id,
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    expert_id=expert_id,
+                )
+            elif "weight" in weight_name:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                )
+            return

         # Case weight scales and zero_points
         if "scale" in weight_name or "zero" in weight_name:
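Aside (not from the commit): the ModelOpt branch added above routes expert tensors by name suffix, so per-tensor quantities (weight_scale_2, input_scale) go through _load_per_tensor_weight_scale and anything else containing "weight" goes through _load_model_weight_or_group_weight_scale. A small sketch of that dispatch with invented tensor names:

# Illustrative dispatch only; mirrors the branch above, the names are made up.
def route_modelopt_weight(weight_name: str) -> str:
    if "weight_scale_2" in weight_name or "input_scale" in weight_name:
        return "_load_per_tensor_weight_scale"
    elif "weight" in weight_name:
        return "_load_model_weight_or_group_weight_scale"
    return "generic scale/zero handling below"

for name in ["w13_weight_scale_2", "w13_input_scale", "w13_weight_scale", "w2_weight"]:
    print(f"{name} -> {route_modelopt_weight(name)}")
# w13_weight_scale_2 and w13_input_scale map to the per-tensor loader;
# w13_weight_scale and w2_weight map to the model-weight/group-scale loader.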
python/sglang/srt/layers/quantization/modelopt_quant.py
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -15,10 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    is_sm100_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
+    is_layer_skipped,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -270,9 +277,16 @@ class ModelOptFp4Config(QuantizationConfig):
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
         kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        if not kv_cache_quant_algo:
+            kv_cache_quant_algo = "auto"
         group_size = quant_config["group_size"]
         exclude_modules = quant_config["exclude_modules"]
         if not (group_size and kv_cache_quant_algo and exclude_modules):
+            logger.warning(
+                f"group_size: {group_size},"
+                f"kv_cache_quant_algo: {kv_cache_quant_algo},"
+                f"exclude_modules: {exclude_modules}"
+            )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
                 "kv_cache_quant_algo specified in "
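For orientation (not part of the diff): the parser above pulls group_size, kv_cache_quant_algo and exclude_modules out of the checkpoint's quantization config and falls back to "auto" when the KV-cache algorithm is empty. A sketch of the kind of dictionary it consumes; only the three indexed keys are taken from the code, while the algorithm key name and all values are illustrative assumptions.

# Hypothetical NVFP4 quantization config as exported by ModelOpt; values are illustrative.
quant_config = {
    "quant_algo": "NVFP4",           # assumed key; "NVFP4" must appear in quant_method
    "kv_cache_quant_algo": "FP8",    # empty/None falls back to "auto"
    "group_size": 16,                # block size for the FP4 block scales
    "exclude_modules": ["lm_head"],  # modules left unquantized
}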
@@ -285,19 +299,30 @@ class ModelOptFp4Config(QuantizationConfig):
exclude_modules
,
)
def
is_layer_excluded
(
self
,
prefix
:
str
,
exclude_modules
:
list
):
import
regex
as
re
for
pattern
in
exclude_modules
:
regex_str
=
pattern
.
replace
(
"."
,
r
"\."
).
replace
(
"*"
,
r
".*"
)
if
re
.
fullmatch
(
regex_str
,
prefix
):
return
True
return
False
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
self
.
exclude_modules
and
any
(
module
in
prefix
for
module
in
self
.
exclude_modules
):
return
None
from
sglang.srt.layers.moe.fused_moe_triton
import
FusedMoE
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
exclude_modules
)
or
self
.
is_layer_excluded
(
prefix
,
self
.
exclude_modules
):
return
UnquantizedLinearMethod
()
return
ModelOptFp4LinearMethod
(
self
)
if
self
.
kv_cache_quant_algo
and
isinstance
(
layer
,
RadixAttention
):
return
ModelOptFp8KVCacheMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
ModelOptNvFp4FusedMoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
...
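A self-contained sketch (not from the commit) of the wildcard matching is_layer_excluded performs: each exclude_modules entry is treated as a glob-like pattern, with "." escaped and "*" widened to ".*", then full-matched against the layer prefix. The standard-library re module stands in here for the regex package imported above, and the prefixes are made up.

import re

def is_layer_excluded(prefix: str, exclude_modules: list) -> bool:
    # Same translation as the method above: "." becomes "\.", "*" becomes ".*", full match.
    for pattern in exclude_modules:
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        if re.fullmatch(regex_str, prefix):
            return True
    return False

exclude = ["lm_head", "model.layers.*.self_attn.kv_b_proj"]
print(is_layer_excluded("lm_head", exclude))                             # True
print(is_layer_excluded("model.layers.3.self_attn.kv_b_proj", exclude))  # True
print(is_layer_excluded("model.layers.3.mlp.gate_proj", exclude))        # False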
@@ -461,3 +486,305 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if bias is not None:
             out = out + bias
         return out.view(*output_shape)
+
+
+class ModelOptNvFp4FusedMoEMethod:
+    """
+    MoE Method for FP4 Quantization with Blockscales and PerTensorScales
+
+    Args:
+        quant_config: NVFP4 Quant Config
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp4Config):
+        self.quant_config = quant_config
+        if not is_sm100_supported():
+            raise ValueError(
+                "Current platform does not support NVFP4"
+                " quantization. Please use Blackwell and"
+                " above."
+            )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, "
+                " dynamic quantization is not supported."
+            )
+
+        layer.num_experts = num_experts
+        layer.params_dtype = params_dtype
+        layer.quant_config = self.quant_config
+        weight_dtype = torch.uint8
+        weight_scale_dtype = torch.float8_e4m3fn
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        # GEMM 1
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # 2 fp4 items are packed in the input dimension
+                hidden_size // 2,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        # GEMM 2
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                hidden_size,
+                # 2 fp4 items are packed in the input dimension
+                intermediate_size_per_partition // 2,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        w13_weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # 2 fp4 items are packed in the input dimension
+                hidden_size // self.quant_config.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                hidden_size,
+                # 2 fp4 items are packed in the input dimension
+                intermediate_size_per_partition // self.quant_config.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+        )
+
+        w13_weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
+
+        w2_weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(num_experts, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        w13_input_scale = PerTensorScaleParameter(
+            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = PerTensorScaleParameter(
+            data=torch.empty(num_experts, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def swizzle_blockscale(self, scale: torch.tensor):
+        assert scale.dtype == torch.float8_e4m3fn
+        # Pad and blockwise interleave weight_scale
+        scale_ndim = scale.ndim
+        if scale.ndim == 2:
+            scale = scale.unsqueeze(0)
+        assert scale.ndim == 3
+        B, M, K = scale.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+        padded_scale[:B, :M, :K] = scale
+        batches, rows, cols = padded_scale.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+        swizzled_scale = swizzled_scale.contiguous().cuda()
+        return (
+            swizzled_scale.reshape(M, K)
+            if scale_ndim == 2
+            else swizzled_scale.reshape(B, M, K)
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # GEMM 1
+        if not torch.allclose(
+            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
+        ):
+            logger.warning_once(
+                "w1_weight_scale_2 must match w3_weight_scale_2. "
+                "Accuracy may be affected."
+            )
+
+        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
+        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
+
+        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        layer.g1_alphas = Parameter(
+            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+
+        assert (
+            layer.w13_weight_scale.shape[2] % 16 == 0
+        ), "Expected weight_scale.dim(1) to be divisible by 16"
+        assert (
+            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
+        ), "Weight Blockscale must be represented as FP8-E4M3"
+        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+
+        layer.w13_blockscale_swizzled = Parameter(
+            w13_blockscale_swizzled, requires_grad=False
+        )
+
+        # This is for quantization, so we need to invert it.
+        layer.w13_input_scale_quant = Parameter(
+            (1 / w13_input_scale).to(torch.float32), requires_grad=False
+        )
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+        # GEMM 2
+        layer.g2_alphas = Parameter(
+            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+
+        # This is for quantization, so we need to invert it.
+        layer.w2_input_scale_quant = Parameter(
+            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+        )
+
+        assert (
+            layer.w2_weight_scale.shape[2] % 16 == 0
+        ), "Expected weight_scale.dim(1) to be divisible by 16"
+        assert (
+            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
+        ), "Weight Blockscale must be represented as FP8-E4M3"
+        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+
+        layer.w2_blockscale_swizzled = Parameter(
+            w2_blockscale_swizzled, requires_grad=False
+        )
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        device = layer.w13_weight.device
+        layer.cutlass_moe_params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device,
+            num_experts=layer.num_experts,
+            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
+            hidden_size=layer.w13_weight.shape[2] * 2,
+        )  # k
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
+
+        return cutlass_moe_fp4(
+            a=x,
+            a1_gscale=layer.w13_input_scale_quant,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            a2_gscale=layer.w2_input_scale_quant,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            params=layer.cutlass_moe_params,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        ).to(x.dtype)
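Aside (not from the commit): swizzle_blockscale pads the block-scale matrix so rows are a multiple of 128 and columns a multiple of 4, then interleaves it into (rows//128, cols//4, 32, 4, 4) tiles for the blockscaled-FP4 CUTLASS path. A CPU-only sketch of the same reshuffle, minus the FP8-E4M3 assert and the .cuda() move; the sizes are illustrative and chosen so no padding is actually needed.

import torch

def swizzle_blockscale_cpu(scale: torch.Tensor) -> torch.Tensor:
    # Mirrors the padding and permutation above on CPU; dtype is not checked here.
    scale_ndim = scale.ndim
    if scale.ndim == 2:
        scale = scale.unsqueeze(0)
    B, M, K = scale.shape
    round_up = lambda x, m: (x + m - 1) // m * m
    M_padded, K_padded = round_up(M, 128), round_up(K, 4)
    padded = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
    padded[:, :M, :K] = scale
    swizzled = (
        padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
        .permute((0, 1, 4, 3, 2, 5))
        .contiguous()
    )
    return swizzled.reshape(M, K) if scale_ndim == 2 else swizzled.reshape(B, M, K)

# Example: 256 scale rows, hidden_size 7168 with group_size 16 -> 448 scale columns.
scales = torch.randn(256, 448)
print(swizzle_blockscale_cpu(scales).shape)  # torch.Size([256, 448]); same shape, tiled element order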
python/sglang/srt/models/deepseek_v2.py
@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 global_server_args_dict["disable_shared_experts_fusion"] = False
                 log_info_on_rank0(
                     logger,
-                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
+                    "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                 )

     def get_input_embeddings(self) -> nn.Embedding:
@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     self_attn.use_deep_gemm_bmm = True

     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False
     ):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "up_proj.qzeros",
                     "up_proj.scales",
                 ]
+            elif self.quant_config.get_name() == "modelopt_fp4":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "down_proj.weight_scale_2",
+                    "down_proj.input_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "gate_proj.weight_scale_2",
+                    "gate_proj.input_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                    "up_proj.weight_scale_2",
+                    "up_proj.input_scale",
+                ]
             else:
                 raise ValueError(
                     f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
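Aside (not from the commit): for the shared-experts fusion path, each suffix above corresponds to one tensor of the shared expert in the checkpoint. A small sketch of the fully qualified names, assuming the usual DeepSeek layout model.layers.<i>.mlp.shared_experts.<suffix>; the layer index is arbitrary.

# Hypothetical fully qualified names for one MoE layer's shared expert (layer 5 chosen arbitrarily).
fp4_suffixes = [
    "down_proj.weight", "down_proj.weight_scale", "down_proj.weight_scale_2", "down_proj.input_scale",
    "gate_proj.weight", "gate_proj.weight_scale", "gate_proj.weight_scale_2", "gate_proj.input_scale",
    "up_proj.weight", "up_proj.weight_scale", "up_proj.weight_scale_2", "up_proj.input_scale",
]
layer_idx = 5
names = [f"model.layers.{layer_idx}.mlp.shared_experts.{s}" for s in fp4_suffixes]
print(names[0])  # model.layers.5.mlp.shared_experts.down_proj.weight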
@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 if fuse_qkv_a_proj and (
                     "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                 ):
@@ -2151,9 +2166,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                         fused_weight = torch.cat(
                             [q_a_proj_weight, kv_a_proj_weight], dim=0
                         )
-                        param_name = name.replace(
-                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                        )
+                        param_name = (
+                            name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                            if "q_a_proj" in name
+                            else name.replace(
+                                "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                            )
+                        )
                         param = params_dict[param_name]
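Aside (not from the commit): the reworked param_name expression simply maps either source projection onto the fused parameter. A tiny sketch with invented names:

def fused_param_name(name: str) -> str:
    # Mirrors the conditional expression above.
    return (
        name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
        if "q_a_proj" in name
        else name.replace("kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa")
    )

print(fused_param_name("model.layers.0.self_attn.q_a_proj.weight"))
print(fused_param_name("model.layers.0.self_attn.kv_a_proj_with_mqa.weight"))
# both print: model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight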
@@ -2164,6 +2182,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                             cached_a_proj.pop(q_a_proj_name)
                             cached_a_proj.pop(kv_a_proj_name)
                     else:
+                        if (
+                            "k_scale" in name or "v_scale" in name
+                        ) and name not in params_dict:
+                            # modelopt attn kv scale is named differently
+                            if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                name = name.replace("_proj", "attn_mqa")
+                            else:
+                                logger.warning(
+                                    f"Unknown scale found in checkpoint: {name}"
+                                )
                         param = params_dict[name]
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader