sglang / Commits

Unverified commit dce303e2, authored Mar 11, 2025 by lukec, committed by GitHub on Mar 11, 2025

linear support deepgemm (#4199)

Co-authored-by: yinfan98 <1106310035@qq.com>

Parent: 4d27eb9a
Showing 3 changed files with 76 additions and 44 deletions (+76 −44):

python/sglang/srt/layers/quantization/fp8_kernel.py   +36 −28
python/sglang/test/test_block_fp8.py                  +39 −15
test/srt/test_fp8_kernel.py                           +1 −1
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -29,10 +29,13 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
+    import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

 logger = logging.getLogger(__name__)

+_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+

 @triton.jit
 def _per_token_group_quant_fp8(
@@ -722,34 +725,39 @@ def w8a8_block_fp8_matmul(
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )

-    kernel[grid](
-        A,
-        B,
-        C,
-        As,
-        Bs,
-        M,
-        N,
-        K,
-        block_n,
-        block_k,
-        A.stride(-2),
-        A.stride(-1),
-        B.stride(1),
-        B.stride(0),
-        C.stride(-2),
-        C.stride(-1),
-        As.stride(-2),
-        As.stride(-1),
-        Bs.stride(1),
-        Bs.stride(0),
-        **config,
-    )
+    # deepgemm only support bf16
+    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+    else:
+        kernel = (
+            _w8a8_block_fp8_matmul_unrolledx4
+            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            else _w8a8_block_fp8_matmul
+        )
+
+        kernel[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            **config,
+        )

     return C
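Note on the change above: the DeepGEMM path is opt-in. It runs only when the SGL_ENABLE_JIT_DEEPGEMM environment variable is non-zero, CUDA is available, and the output tensor is bfloat16; otherwise the existing Triton kernel is used. Below is a minimal sketch of exercising that path; the exact w8a8_block_fp8_matmul signature, tensor shapes, and scale layouts are assumptions modeled on the tests in this commit, not confirmed by it.

# Sketch only: enable the DeepGEMM path before importing sglang's fp8_kernel
# module, since _enable_jit_deepgemm is read from the environment at import time.
import os

os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"  # default "0" keeps the Triton kernel

import torch

from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

# Assumed call shape, modeled on the tests touched in this commit:
# block-quantized FP8 operands A (M x K) and B (N x K) with per-block scales.
M, N, K = 64, 1536, 7168
block_size = [128, 128]  # assumed [block_n, block_k]
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
As = torch.rand(M, K // block_size[1], device="cuda", dtype=torch.float32)
Bs = torch.rand(N // block_size[0], K // block_size[1], device="cuda", dtype=torch.float32)

# With output_dtype=torch.bfloat16 the new branch dispatches to
# deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C); any other output dtype
# falls back to the Triton _w8a8_block_fp8_matmul kernel.
C = w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.bfloat16)
print(C.shape, C.dtype)  # expected: (M, N), torch.bfloat16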
python/sglang/test/test_block_fp8.py

 import itertools
+import os
 import unittest

 import torch

@@ -11,6 +12,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul,
 )

+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+

 # For test
 def native_per_token_group_quant_fp8(

@@ -208,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # use practical shape in DeepSeek V3 for test
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]

     @classmethod
     def setUpClass(cls):

@@ -222,7 +247,8 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")

-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2

@@ -257,19 +283,17 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NKs=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)
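For reference, the reworked parameterization keeps each (N, K) pair together in NKs instead of crossing N and K independently, which is what lets the CUDA branch enumerate only the practical DeepSeek-V3 weight shapes above. A self-contained illustration of the pattern follows; the class, values, and assertion are hypothetical and not part of the commit.

# Minimal illustration of the paired-shape parameterization used by the test:
# (N, K) travel together as one NKs entry, so itertools.product never generates
# unintended N x K combinations.
import itertools
import unittest


class ShapePairExample(unittest.TestCase):
    M = [1, 64]
    NKs = [(1536, 7168), (512, 7168)]  # shape pairs kept together
    SEEDS = [0]

    def test_pairs(self):
        for M, (N, K), seed in itertools.product(self.M, self.NKs, self.SEEDS):
            with self.subTest(M=M, N=N, K=K, seed=seed):
                self.assertGreater(M * N * K, 0)  # placeholder assertion


if __name__ == "__main__":
    unittest.main()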
test/srt/test_fp8_kernel.py

@@ -17,7 +17,7 @@ class TestFP8Base(unittest.TestCase):
         cls.K = 512
         cls.group_size = 128
         cls.quant_type = torch.float8_e4m3fn
-        cls.output_type = torch.float16
+        cls.output_type = torch.bfloat16

     @staticmethod
     def _make_A(M, K, group_size, out_dtype):
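This one-line change switches the expected output dtype to bfloat16, consistent with the "deepgemm only support bf16" constraint in the kernel change. A hedged sketch of running the block-FP8 tests with the new flag enabled is below; only the SGL_ENABLE_JIT_DEEPGEMM environment variable is confirmed by this commit, and the module path assumes sglang is installed as a package.

# Sketch (assumption): run the block-FP8 tests with the DeepGEMM path turned on.
import os
import subprocess
import sys

env = dict(os.environ, SGL_ENABLE_JIT_DEEPGEMM="1")
subprocess.run(
    [sys.executable, "-m", "unittest", "-v", "sglang.test.test_block_fp8"],
    env=env,
    check=True,
)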