sglang commit 1e3e5215

Unverified commit 1e3e5215, authored Jan 27, 2025 by yizhang2077; committed by GitHub on Jan 27, 2025.

add unit test for block wise fp8 (#3156)

Parent: fb11a439
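The commit exercises two kernels from sglang.srt.layers.quantization.fp8_kernel: per_token_group_quant_fp8, which quantizes each group of group_size consecutive activation values to float8_e4m3fn with one fp32 scale per group, and w8a8_block_fp8_matmul, which multiplies the fp8 operands using per-block weight scales. As a rough mental model, per-token-group quantization can be sketched like this (a minimal illustration under assumed semantics, not the sglang Triton kernel):

import torch


def per_token_group_quant_sketch(A, group_size=128, dtype=torch.float8_e4m3fn):
    # Split each row of A into contiguous groups of `group_size` values.
    M, K = A.shape
    grouped = A.reshape(M, K // group_size, group_size)
    # One fp32 scale per (token, group): map the group's abs-max onto the
    # largest representable fp8 value (448 for e4m3fn).
    fmax = torch.finfo(dtype).max
    scale = grouped.abs().amax(dim=-1, keepdim=True) / fmax
    quant = (grouped / scale).to(dtype)
    # Dequantization (quant.float() * scale) recovers A up to fp8 rounding.
    return quant.reshape(M, K), scale.squeeze(-1)

The _make_A and _make_B helpers in the new test invert this process: they draw random fp8-representable values and scales first, then multiply them out to build a full-precision operand whose quantization result is known exactly.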
Showing 2 changed files with 130 additions and 0 deletions:

    test/srt/run_suite.py        +1    -0
    test/srt/test_fp8_kernel.py  +129  -0
test/srt/run_suite.py
@@ -52,6 +52,7 @@ suites = {
         "test_w8a8_quantization.py",
         "test_session_control.py",
         "test_fp8_kvcache.py",
+        "test_fp8_kernel.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
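This registers the new file right after test_fp8_kvcache.py in the same suite as the other quantization tests (the list just above "nightly"), so it runs with that suite rather than in the nightly jobs.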
test/srt/test_fp8_kernel.py (new file, mode 0 → 100644)
import unittest

import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)


class TestFP8Base(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.M = 256
        # test non-aligned
        cls.N = 1024 + 64
        cls.K = 512
        cls.group_size = 128
        cls.quant_type = torch.float8_e4m3fn
        cls.output_type = torch.float16

    @staticmethod
    def _make_A(M, K, group_size, out_dtype):
        quant_A = torch.rand(
            M, K // group_size, group_size, dtype=torch.float32, device="cuda"
        )
        # -1 ~ 1
        quant_A = quant_A * 2 - 1
        # scaling abs max to fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
        quant_A *= scaling
        quant_A = quant_A.to(out_dtype).to(torch.float32)

        # create scale and A
        scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda")
        scale /= fmax
        A = quant_A * scale[..., None]

        A = A.reshape(M, K)
        quant_A = quant_A.reshape(M, K).to(out_dtype)
        return A, quant_A, scale

    @staticmethod
    def _make_B(K, N, group_size, out_dtype):
        def _aligned_size(a, b):
            return (a + b - 1) // b * b

        K_aligned = _aligned_size(K, group_size)
        N_aligned = _aligned_size(N, group_size)

        quant_B = torch.rand(
            K_aligned // group_size,
            group_size,
            N_aligned // group_size,
            group_size,
            dtype=torch.float32,
            device="cuda",
        )
        quant_B = quant_B * 2 - 1

        # scaling abs max to fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True)
        quant_B *= scaling
        quant_B = quant_B.to(out_dtype).to(torch.float32)

        scale = torch.rand(
            K_aligned // group_size,
            1,
            N_aligned // group_size,
            1,
            dtype=torch.float32,
            device="cuda",
        )
        scale /= fmax

        B = quant_B * scale
        B = B.reshape(K_aligned, N_aligned)[:K, :N]
        quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]
        scale = scale.reshape(K_aligned // group_size, N_aligned // group_size)
        return B, quant_B, scale


class TestPerTokenGroupQuantFP8(TestFP8Base):
    def test_per_token_group_quant_fp8(self):
        if torch.cuda.get_device_capability()[0] < 9:
            return
        A, A_quant_gt, scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        A_quant, scale = per_token_group_quant_fp8(
            x=A, group_size=self.group_size, dtype=self.quant_type
        )
        torch.testing.assert_close(scale, scale_gt)
        diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
        diff_count = (diff > 1e-5).count_nonzero()
        assert diff_count / diff.numel() < 1e-4


class TestW8A8BlockFP8Matmul(TestFP8Base):
    def test_w8a8_block_fp8_matmul(self):
        if torch.cuda.get_device_capability()[0] < 9:
            return
        A, A_quant_gt, A_scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        B, B_quant_gt, B_scale_gt = self._make_B(
            K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
        )
        C_gt = A.to(self.output_type) @ B.to(self.output_type)
        C = w8a8_block_fp8_matmul(
            A=A_quant_gt,
            B=B_quant_gt.T.contiguous(),
            As=A_scale_gt,
            Bs=B_scale_gt.T.contiguous(),
            block_size=[128, 128],
            output_dtype=self.output_type,
        )
        torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)


if __name__ == "__main__":
    unittest.main()
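The matmul test compares w8a8_block_fp8_matmul against a plain fp16 A @ B on the dequantized operands, with a loose atol=0.5 to absorb fp8 rounding error accumulated across the K reduction. What the kernel is expected to compute can be sketched densely as follows (an illustration under assumed semantics: B arrives in N×K layout with one scale per 128×128 block, matching the .T.contiguous() calls above; for simplicity the sketch assumes N and K are multiples of block, whereas the real kernel also handles the non-aligned N = 1088 this test exercises):

import torch


def w8a8_block_fp8_matmul_reference(A_q, B_q, As, Bs, block=128, out_dtype=torch.float16):
    # A_q: (M, K) fp8; As: (M, K // block) fp32, one scale per K-group.
    M, K = A_q.shape
    A = A_q.to(torch.float32).reshape(M, K // block, block) * As[..., None]
    A = A.reshape(M, K)
    # B_q: (N, K) fp8; Bs: (N // block, K // block) fp32, one scale per block.
    N, _ = B_q.shape
    B = B_q.to(torch.float32).reshape(N // block, block, K // block, block)
    B = (B * Bs[:, None, :, None]).reshape(N, K)
    return (A @ B.T).to(out_dtype)

Both tests early-return on GPUs with compute capability below 9 (pre-Hopper hardware), and the file runs standalone:

python3 test/srt/test_fp8_kernel.py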