change / sglang / Commits / 12eb02e9

Commit 12eb02e9 (unverified)
Change bf16 to fp8 for some gemms in attention for DeepSeek ckpt v2 (#11805)
Authored Oct 19, 2025 by fzyzcjy; committed by GitHub on Oct 19, 2025
Parent: 002d0373
Showing 2 changed files with 107 additions and 12 deletions:

  python/sglang/srt/layers/quantization/fp8_utils.py   (+41, -11)
  python/sglang/srt/models/deepseek_v2.py              (+66, -1)
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
 
 try:
     from vllm import _custom_ops as ops
@@ -441,23 +441,53 @@ def _requant_weight_ue8m0(
         torch.bfloat16,
     )
 
+    out_w, out_s = quant_weight_ue8m0(
+        weight_dequant=weight_dequant,
+        weight_block_size=weight_block_size,
+    )
+    out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
+    return out_w, out_s
+
+
+def quant_weight_ue8m0(
+    weight_dequant: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+    assert (
+        weight_dequant.dtype == torch.bfloat16
+    ), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
+
+    *batch_dims, n, k = weight_dequant.shape
+
     weight_dequant_flat = weight_dequant.view((-1, k))
     out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
 
-    out_w = out_w_flat.view(weight.shape)
-    out_s = out_s_flat.view(weight_scale_inv.shape)
+    out_w = out_w_flat.view((*batch_dims, n, k))
+    out_s = out_s_flat.view(
+        (
+            *batch_dims,
+            ceil_div(n, weight_block_size[0]),
+            ceil_div(k, weight_block_size[1]),
+        )
+    )
+
+    return out_w, out_s
 
-    # NOTE copy and modified from DeepGEMM
-    def _transform_scale(sf, mn: int):
-        import deep_gemm.utils.layout
 
-        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
-        return sf
+def transform_scale_ue8m0_inplace(param, mn):
+    param.data = _transform_scale_ue8m0(param.data, mn=mn)
 
-    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
 
-    return out_w, out_s
+# NOTE copy and modified from DeepGEMM
+def _transform_scale_ue8m0(sf, mn):
+    import deep_gemm.utils.layout
+
+    sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+    return sf
 
 
 # COPIED FROM DeepGEMM
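Not part of the commit — a minimal runnable sketch of what block-wise FP8 quantization with a [128, 128] block size produces, to make the shape math in quant_weight_ue8m0 concrete. toy_per_block_cast_to_fp8 is a hypothetical stand-in, not sglang's per_block_cast_to_fp8 kernel, and it assumes a PyTorch build with the float8_e4m3fn dtype.

# Hypothetical stand-in: quantize a bf16 [n, k] weight to fp8 with one scale per
# 128x128 block, so the scale tensor has shape [ceil_div(n, 128), ceil_div(k, 128)]
# -- the same shape quant_weight_ue8m0 reshapes out_s to.
import torch


def toy_per_block_cast_to_fp8(w: torch.Tensor, block: int = 128):
    n, k = w.shape
    n_blk, k_blk = -(-n // block), -(-k // block)  # ceil_div
    padded = torch.zeros(n_blk * block, k_blk * block, dtype=w.dtype)
    padded[:n, :k] = w
    blocks = padded.view(n_blk, block, k_blk, block)
    # One scale per block, chosen so the block's max maps to the e4m3fn max (448).
    amax = blocks.float().abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    scale = amax / 448.0
    q = (blocks / scale).to(torch.float8_e4m3fn)
    out_w = q.view(n_blk * block, k_blk * block)[:n, :k]
    out_s = scale.view(n_blk, k_blk)  # dequant scale ("weight_scale_inv"), one per block
    return out_w, out_s


w = torch.randn(2048, 7168, dtype=torch.bfloat16)
qw, s = toy_per_block_cast_to_fp8(w)
print(qw.dtype, qw.shape, s.shape)  # torch.float8_e4m3fn, [2048, 7168], [16, 56]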
python/sglang/srt/models/deepseek_v2.py
@@ -94,7 +94,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     channel_quant_to_tensor_quant,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -1098,7 +1100,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -2393,6 +2395,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -3130,6 +3143,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
+
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             self._transform_scale_nextn_moe_ue8m0()
@@ -3198,6 +3218,25 @@ class DeepseekV2ForCausalLM(nn.Module):
                 module.weight, module.weight_scale_inv, weight_block_size
             )
 
+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
     # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
     def _transform_scale_nextn_moe_ue8m0(self):
         layer = self.model.decoder
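Not part of the commit — a small sketch of what the index_select step inside _transform_scale_ue8m0 does (DeepGEMM's TMA-aligned packing call is omitted here): each per-128-row block scale is repeated so there is one scale row per output row, which is why the call above passes mn=module.weight.shape[-2].

import torch

mn = 256                                 # number of output rows of the weight (shape[-2])
sf = torch.tensor([[1.0, 2.0, 4.0],
                   [8.0, 16.0, 32.0]])   # per-block scales, shape [mn // 128, k_blocks]
expanded = sf.index_select(-2, torch.arange(mn) // 128)
print(expanded.shape)                    # torch.Size([256, 3]); rows 0-127 repeat sf[0], rows 128-255 repeat sf[1]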
@@ -3235,6 +3274,8 @@ class DeepseekV2ForCausalLM(nn.Module):
         else:
             raise ValueError("num_nextn_predict_layers is not in the config")
 
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
+
         if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
             weights = self._quant_nextn_moe_to_fp8_ue8m0(
                 weights, nextn_layer_id=nextn_layer_id
@@ -3469,6 +3510,30 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     # TODO avoid code dup
     def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
         weights_dict = dict(weights)
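Not part of the commit — a self-contained sketch of the load-time rewrite pattern that _quant_attn_to_fp8_ue8m0 applies to the (name, tensor) weight stream: targeted q_b_proj weights are replaced and a matching .weight_scale_inv entry is added. toy_quant below is a hypothetical stand-in for quant_weight_ue8m0.

import torch


def toy_quant(w: torch.Tensor, block: int = 128):
    # Stand-in quantizer: returns the weight unchanged plus a dummy per-block scale,
    # just to show which entries get added or replaced in the weight stream.
    n, k = w.shape
    return w, torch.ones(-(-n // block), -(-k // block))


def quant_attn_weights(weights, num_layers):
    weights_dict = dict(weights)
    for layer_id in range(num_layers):
        partial_name = f"model.layers.{layer_id}.self_attn.q_b_proj"
        wname = f"{partial_name}.weight"
        if wname not in weights_dict:
            continue
        out_w, out_s = toy_quant(weights_dict[wname])
        weights_dict[wname] = out_w
        weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
    return list(weights_dict.items())


weights = [("model.layers.0.self_attn.q_b_proj.weight", torch.randn(3072, 1536))]
for name, tensor in quant_attn_weights(weights, num_layers=1):
    print(name, tuple(tensor.shape))
# model.layers.0.self_attn.q_b_proj.weight (3072, 1536)
# model.layers.0.self_attn.q_b_proj.weight_scale_inv (24, 12)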