sglang: Commit bdde2375 (unverified)

[perf] experimental enhance fp8 per-tensor quant (#5370)

Authored Apr 15, 2025 by JieXin Liang; committed by GitHub Apr 14, 2025
Parent: e9fc2ac7

Showing 4 changed files with 178 additions and 13 deletions (+178, -13)
python/sglang/srt/layers/quantization/fp8_kernel.py  +100 / -0
python/sglang/srt/layers/quantization/fp8_utils.py   +12 / -4
python/sglang/srt/models/deepseek_v2.py              +9 / -9
python/sglang/test/test_block_fp8.py                 +57 / -0
python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -839,3 +839,103 @@ def w8a8_block_fp8_matmul(
     )
     return C
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage1(
+    x_ptr,
+    x_s_ptr,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    eps,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
+
+    tl.atomic_max(x_s_ptr, _absmax / fp8_max)
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage2(
+    x_ptr,
+    x_s_ptr,
+    x_q_ptr,
+    num_seq,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    fp8_min,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_s = tl.load(x_s_ptr)
+    x_s_inv = 1.0 / x_s
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size
+
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
+
+    tl.store(x_q_ptr + offset, x_q, mask=mask)
+
+
+def per_tensor_quant_mla_fp8(
+    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with tensor-wise quantization
+    and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    x_q = x.new_empty(x.size(), dtype=dtype)
+    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+
+    num_head, num_seq, head_size = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(head_size)
+    grid = (num_seq, num_head)
+
+    _per_tensor_quant_mla_fp8_stage1[grid](
+        x,
+        x_s,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        eps,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+    _per_tensor_quant_mla_fp8_stage2[grid](
+        x,
+        x_s,
+        x_q,
+        num_seq,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        -fp8_max,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s
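The two-stage kernel above is, functionally, a global absmax reduction (stage 1, accumulated across blocks with tl.atomic_max) followed by a divide/clamp/cast (stage 2) that also packs the possibly strided input into a contiguous output. For intuition only, here is a minimal plain-PyTorch sketch of the same math; the function name is illustrative and not part of this commit, and the HIP-specific e4m3fnuz handling is omitted.

import torch

def per_tensor_quant_mla_fp8_reference(
    x: torch.Tensor,
    eps: float = 1e-12,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    # Stage 1 equivalent: a single scale for the whole tensor, max(|x|, eps) / fp8_max.
    # Stage 2 equivalent: divide by that scale, clamp to the fp8 range, cast,
    # and return a packed (contiguous) result, as the Triton kernel does.
    assert x.dim() == 3, "`x` is not a 3d-tensor"
    fp8_max = torch.finfo(dtype).max
    x_f32 = x.to(torch.float32)
    x_s = x_f32.abs().amax().clamp_min(eps) / fp8_max
    x_q = (x_f32 / x_s).clamp(-fp8_max, fp8_max).to(dtype)
    return x_q.contiguous(), x_s.reshape(1)

Dequantizing with x_q.to(torch.float32) * x_s should recover x up to fp8 rounding, which is what the new test below compares against input_to_float8.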
python/sglang/srt/layers/quantization/fp8_utils.py
@@ -168,13 +168,13 @@ def input_to_float8(
     """This function quantizes input values to float8 values with tensor-wise quantization."""
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
         dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0

     scale = fp8_max / amax
-    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
@@ -213,7 +213,11 @@ def block_quant_to_tensor_quant(
         for j in range(n_tiles):
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

-    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_block)
+        if _is_cuda
+        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    )
     return x_q_tensor, scale
@@ -222,7 +226,11 @@ def channel_quant_to_tensor_quant(
     x_s: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     x_dq_channel = x_q_channel.to(torch.float32) * x_s
-    x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_channel)
+        if _is_cuda
+        else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    )
     return x_q_tensor, scale
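Both edits to input_to_float8 promote intermediate math to float32: the amax/scale computation and the saturating multiply previously ran in the input dtype, and in fp16 the scale fp8_max / amax can overflow to inf when the tensor's absolute maximum is tiny. The other two hunks route tensor-wise requantization through sgl_scaled_fp8_quant on CUDA, falling back to input_to_float8 elsewhere. A small illustrative snippet (not part of the patch) showing the fp16 failure mode the .float() calls avoid:

import torch

x = torch.full((4,), 1e-7, dtype=torch.float16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max       # 448.0
amax_half = x.abs().amax().clamp(min=1e-12)          # stays in fp16, as before the patch
amax_f32 = x.abs().amax().float().clamp(min=1e-12)   # promoted, as after the patch
print(fp8_max / amax_half)  # inf: 448 / ~1e-7 exceeds the fp16 range
print(fp8_max / amax_f32)   # a finite float32 scale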
python/sglang/srt/models/deepseek_v2.py
@@ -53,10 +53,10 @@ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
-    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.int8_utils import (
@@ -817,8 +817,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -848,8 +848,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -895,8 +895,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -991,8 +991,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
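In the MLA attention path, the activation quantizer switches from the torch-based input_to_float8 to the new Triton per_tensor_quant_mla_fp8; the fp8 operands and their per-tensor scales are then passed to bmm_fp8 unchanged. As a rough mental model only (this is not the sgl-kernel bmm_fp8 API, and the helper name below is illustrative), the scaled fp8 batched matmul corresponds to dequantizing both operands and running a plain bmm:

import torch

def bmm_fp8_reference(a_q, a_s, b_q, b_s, out_dtype=torch.bfloat16):
    # Illustrative stand-in: dequantize the per-tensor-quantized fp8 operands,
    # do the batched matmul in float32, then cast to the requested output dtype.
    a = a_q.to(torch.float32) * a_s
    b = b_q.to(torch.float32) * b_s
    return torch.bmm(a, b).to(out_dtype)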
python/sglang/test/test_block_fp8.py
@@ -7,10 +7,12 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.layers.quantization.fp8_utils import input_to_float8
 from sglang.test.test_utils import CustomTestCase

 _is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase):
                 self._static_quant_fp8(*params)


+class TestPerTensorQuantMlaFP8(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    LAST_D_EXT = [1024, 0]
+    LAST_D = [512]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_tensor_quant_mla_fp8(
+        self, num_tokens, d, last_d_ext, last_d, dtype, seed
+    ):
+        torch.manual_seed(seed)
+
+        x = torch.rand(
+            (num_tokens, d // last_d, last_d + last_d_ext),
+            dtype=dtype,
+        )
+        x_sub, _ = x.split([last_d, last_d_ext], dim=-1)
+
+        with torch.inference_mode():
+            ref_out, ref_s = input_to_float8(x_sub.transpose(0, 1))
+            out, out_s = per_tensor_quant_mla_fp8(x_sub.transpose(0, 1))
+
+        self.assertTrue(out.is_contiguous())
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+        self.assertTrue(
+            torch.allclose(out_s.to(torch.float32), ref_s.to(torch.float32))
+        )
+
+    def test_per_tensor_quant_mla_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.LAST_D_EXT,
+            self.LAST_D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                last_d_ext=params[2],
+                last_d=params[3],
+                dtype=params[4],
+                seed=params[5],
+            ):
+                self._per_tensor_quant_mla_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
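The new TestPerTensorQuantMlaFP8 case deliberately feeds the kernel a strided, non-contiguous view (a split along the last dim followed by transpose(0, 1)) and asserts that the output is contiguous, which exercises the x.stride(0)/x.stride(1) handling in the kernel launch. A standalone sketch of that input construction, with illustrative sizes, that runs on CPU:

import torch

num_tokens, last_d, last_d_ext = 7, 512, 1024
x = torch.rand((num_tokens, 4096 // last_d, last_d + last_d_ext), dtype=torch.bfloat16)
x_sub, _ = x.split([last_d, last_d_ext], dim=-1)  # strided view along the last dim
x_t = x_sub.transpose(0, 1)                       # (heads, tokens, last_d), non-contiguous
print(x_t.shape, x_t.is_contiguous(), x_t.stride())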