Commit 5b1afa78 (unverified)
Authored Jun 14, 2025 by fzyzcjy; committed by GitHub on Jun 13, 2025

Re-quantize DeepSeek model weights to support DeepGEMM new input format (#7156)

Parent: c49c1d92
Showing 3 changed files, with 125 additions and 0 deletions (+125, -0)
python/sglang/math_utils.py                          +8   -0
python/sglang/srt/layers/quantization/fp8_utils.py   +61  -0
python/sglang/srt/models/deepseek_v2.py              +56  -0
python/sglang/math_utils.py (new file, mode 100644)
# COPIED FROM DeepGEMM
def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y


# COPIED FROM DeepGEMM
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y
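As a quick sanity check of the two helpers above (a minimal sketch; the import path follows the new module added in this commit):

# Values follow directly from the definitions above.
from sglang.math_utils import align, ceil_div

assert ceil_div(250, 128) == 2   # 250 / 128, rounded up
assert align(250, 128) == 256    # smallest multiple of 128 that is >= 250
assert align(256, 128) == 256    # already-aligned values are unchanged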
python/sglang/srt/layers/quantization/fp8_utils.py
...
@@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Tuple
 import torch
 
+from sglang.math_utils import align
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
...
@@ -390,6 +391,66 @@ def block_quant_dequant(
     return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
+
+
+def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
+    assert isinstance(weight, torch.nn.Parameter)
+    assert isinstance(weight_scale_inv, torch.nn.Parameter)
+    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
+        weight, weight_scale_inv, weight_block_size
+    )
+
+
+def _requant_weight_ue8m0(
+    weight: torch.Tensor,
+    weight_scale_inv: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+
+    *_, n, k = weight.shape
+
+    weight_dequant = block_quant_dequant(
+        weight,
+        weight_scale_inv,
+        weight_block_size,
+        torch.bfloat16,
+    )
+
+    weight_dequant_flat = weight_dequant.view((-1, k))
+    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
+
+    out_w = out_w_flat.view(weight.shape)
+    out_s = out_s_flat.view(weight_scale_inv.shape)
+
+    # NOTE copy and modified from DeepGEMM
+    def _transform_scale(sf, mn: int):
+        import deep_gemm.utils.layout
+
+        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        return sf
+
+    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+# COPIED FROM DeepGEMM
+def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = ceil_to_ue8m0(x_amax / 448.0)
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2)
+    )
+
+
 # COPIED FROM DeepGEMM
 def ceil_to_ue8m0(x: torch.Tensor):
     return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
...
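To illustrate the new quantization path, here is a minimal sketch that runs per_block_cast_to_fp8 on a random matrix, assuming the helpers above are importable from sglang.srt.layers.quantization.fp8_utils and that the local PyTorch build supports torch.float8_e4m3fn:

# Minimal sketch: quantize a random matrix with per_block_cast_to_fp8 and check
# that the returned scales are UE8M0, i.e. pure powers of two (8-bit exponent,
# no mantissa). Shapes here are arbitrary multiples of the 128x128 block size.
import torch

from sglang.srt.layers.quantization.fp8_utils import per_block_cast_to_fp8

x = torch.randn(256, 512, dtype=torch.bfloat16)
x_fp8, scales = per_block_cast_to_fp8(x)

print(x_fp8.shape, x_fp8.dtype)  # torch.Size([256, 512]) torch.float8_e4m3fn
print(scales.shape)              # torch.Size([2, 4]) -- one scale per 128x128 block
assert torch.equal(scales.log2(), scales.log2().round())  # exponent-only scales

The full _requant_weight_ue8m0 path additionally repacks these scales via deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor, so running it requires a DeepGEMM installation.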
python/sglang/srt/models/deepseek_v2.py
...
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    requant_weight_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
...
@@ -1935,6 +1936,61 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
             self_attn.use_deep_gemm_bmm = True
 
+        if False:  # TODO (pr-chain)
+            self._weight_requant_ue8m0()
+
+    def _weight_requant_ue8m0(self):
+        weight_block_size = self.quant_config.weight_block_size
+
+        moe_layers = list(
+            range(
+                self.config.first_k_dense_replace,
+                self.config.num_hidden_layers,
+                self.config.moe_layer_freq,
+            )
+        )
+
+        for layer_id in range(self.config.num_hidden_layers):
+            layer = self.model.layers[layer_id]
+
+            for module in [
+                layer.self_attn.fused_qkv_a_proj_with_mqa,
+                layer.self_attn.q_b_proj,
+                layer.self_attn.kv_b_proj,
+                layer.self_attn.o_proj,
+            ]:
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
+            if layer_id in moe_layers:
+                shared_experts = layer.mlp.shared_experts
+                for module in [
+                    shared_experts.gate_up_proj,
+                    shared_experts.down_proj,
+                ]:
+                    requant_weight_ue8m0_inplace(
+                        module.weight, module.weight_scale_inv, weight_block_size
+                    )
+
+                experts = layer.mlp.experts
+                if isinstance(experts, DeepEPMoE):
+                    for w in [
+                        experts.w13_weight_fp8,
+                        experts.w2_weight_fp8,
+                    ]:
+                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
+            else:
+                mlp = layer.mlp
+                assert isinstance(mlp, DeepseekV2MLP)
+                for module in [
+                    mlp.gate_up_proj,
+                    mlp.down_proj,
+                ]:
+                    requant_weight_ue8m0_inplace(
+                        module.weight, module.weight_scale_inv, weight_block_size
+                    )
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
...
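For reference, a hypothetical toy call of the new hook on a single module, mirroring the per-module loop in _weight_requant_ue8m0 above. The shapes and the per-128x128-block scale layout are assumptions for illustration, and running it needs DeepGEMM installed, since the scales are repacked through deep_gemm.utils.layout.

# Hypothetical toy example; the real call sites are the attention/MLP projections
# iterated over in _weight_requant_ue8m0 above.
import torch

from sglang.srt.layers.quantization.fp8_utils import requant_weight_ue8m0_inplace

n, k = 256, 512
weight = torch.nn.Parameter(
    torch.randn(n, k).to(torch.float8_e4m3fn), requires_grad=False
)
# One float32 scale per 128x128 weight block (assumed layout).
weight_scale_inv = torch.nn.Parameter(
    torch.rand(n // 128, k // 128, dtype=torch.float32), requires_grad=False
)

requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size=[128, 128])
# weight.data now holds FP8 values rescaled against power-of-two block scales,
# and weight_scale_inv.data holds those scales in DeepGEMM's packed layout.

Note that in this commit the hook is still gated off (`if False:  # TODO (pr-chain)`), so it is not yet invoked during model loading.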