sglang · Commit 339f8eef (unverified)
Authored Sep 05, 2025 by fzyzcjy; committed via GitHub on Sep 05, 2025

[1/2] Optimizations and refactors about quant kernel (#9534)

Parent: afd9f2f5
Changes: 11 changed files with 1002 additions and 335 deletions (+1002, -335)
Changed files:
  python/sglang/srt/bench_utils.py                            +4    -2
  python/sglang/srt/layers/quantization/fp8_kernel.py         +29   -8
  python/sglang/srt/layers/quantization/int8_kernel.py        +8    -2
  sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py    +203  -48
  sgl-kernel/csrc/common_extension.cc                         +3    -8
  sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu          +447  -177
  sgl-kernel/include/sgl_kernel_ops.h                         +6    -12
  sgl-kernel/python/sgl_kernel/__init__.py                    +1    -2
  sgl-kernel/python/sgl_kernel/gemm.py                        +15   -18
  sgl-kernel/python/sgl_kernel/test_utils.py                  +125  -0
  sgl-kernel/tests/test_per_token_group_quant_8bit.py         +161  -58
python/sglang/srt/bench_utils.py  (+4, -2)

 import os
+import re
 import sys
 from contextlib import nullcontext
...
@@ -108,7 +109,8 @@ def bench_kineto(
     if not with_multiple_kernels:
         for name in kernel_names:
             assert (
-                sum([name in line for line in prof_lines]) == 1
+                sum([int(re.search(name, line) is not None) for line in prof_lines]) == 1
             ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

     # Save chrome traces
...
@@ -122,7 +124,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
         for line in prof_lines:
-            if name in line:
+            if re.search(name, line) is not None:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():
...
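A minimal sketch (not part of the commit) of why the matching above moves from substring checks to re.search: a kernel-name spec can now be a regex alternation such as the "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel" pattern used by the benchmark below, which plain `in` cannot express. The profiler lines here are made up for illustration.

import re

prof_lines = [
    "  _per_token_group_quant_8bit           12.3us  100",
    "  _silu_and_mul_post_quant_kernel        4.5us  100",
]
pattern = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"

# re.search treats the spec as a pattern, so one spec can cover several kernels.
matched = [line for line in prof_lines if re.search(pattern, line) is not None]
print(len(matched))  # 2 -- the alternation matches either kernel name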
python/sglang/srt/layers/quantization/fp8_kernel.py  (+29, -8)

...
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
-    from sgl_kernel import (
-        sgl_per_tensor_quant_fp8,
-        sgl_per_token_group_quant_fp8,
-        sgl_per_token_quant_fp8,
-    )
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False

 if _is_hip:
     if _use_aiter:
...
@@ -496,6 +502,21 @@ def sglang_per_token_group_quant_fp8(
     )

     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
-            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-        )
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0,
+                fuse_silu_and_mul, masked_m,
+            )
+        else:
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )
...
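For orientation, a hedged pure-PyTorch reference of what per-token-group FP8 quantization computes (this is an illustrative sketch, not the Triton or CUDA kernel, and it ignores column-major/TMA/ue8m0 scale layouts): split each token's hidden dimension into groups of group_size, take the per-group absolute maximum, derive one scale per (token, group), and cast the scaled values to float8_e4m3fn. The function and variable names are mine.

import torch

def per_token_group_quant_fp8_reference(x, group_size, eps=1e-10):
    # x: (num_tokens, hidden_dim), hidden_dim assumed divisible by group_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    num_tokens, hidden_dim = x.shape
    xg = x.float().reshape(num_tokens, hidden_dim // group_size, group_size)
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = amax / fp8_info.max                      # one scale per (token, group)
    x_q = (xg / scale).clamp(fp8_info.min, fp8_info.max).to(torch.float8_e4m3fn)
    return x_q.reshape(num_tokens, hidden_dim), scale.squeeze(-1)

x = torch.randn(4, 256, dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8_reference(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])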
python/sglang/srt/layers/quantization/int8_kernel.py  (+8, -2)

...
@@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+    except ImportError:
+        from sgl_kernel import (
+            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
+        )

 logger = logging.getLogger(__name__)
...
@@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)

     return x_q, x_s
...
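The INT8 path is the same per-token-group scheme with integer rounding instead of an FP8 cast. A short illustrative sketch under the same assumptions as the FP8 reference above (names are mine, not the kernel's API):

import torch

def per_token_group_quant_int8_reference(x, group_size, eps=1e-10):
    iinfo = torch.iinfo(torch.int8)
    num_tokens, hidden_dim = x.shape
    xg = x.float().reshape(num_tokens, hidden_dim // group_size, group_size)
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = amax / iinfo.max
    # Symmetric int8: round, then clamp into [-128, 127].
    x_q = (xg / scale).round().clamp(iinfo.min, iinfo.max).to(torch.int8)
    return x_q.reshape(num_tokens, hidden_dim), scale.squeeze(-1)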
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py  (+203, -48)

 import itertools
+import os
 import time
 from functools import partial
 from pathlib import Path

 import torch
 import triton
+from sgl_kernel.test_utils import create_per_token_group_quant_test_data

 from sglang.srt.bench_utils import bench_kineto
 from sglang.srt.layers.quantization.fp8_kernel import (
...
@@ -19,78 +21,231 @@ from sglang.srt.utils import is_hip
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
-hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
-group_size_range = [128]  # For DeepSeek V3/R1
-dst_dtype_range = [fp8_type_]  # TODO test int8
-flags_range = [
-    dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
-]
-configs = list(
-    itertools.product(
-        num_tokens_range, hidden_dim_range, group_size_range, dst_dtype_range, flags_range
-    )
-)
+mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"
+
+if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
+    # configs = [[
+    #     768,
+    #     16384,
+    #     128,
+    #     None,
+    #     fp8_type_,
+    #     dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True,
+    #          fuse_silu_and_mul=False, masked_layout_mode=None),
+    # ]]
+    configs = [
+        [
+            768 * 8,
+            2048,
+            128,
+            48,
+            fp8_type_,
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                # masked_layout_mode=None,
+                masked_layout_mode="balanced",
+                # masked_layout_mode="extreme",
+            ),
+        ]
+    ]
+elif mode_concentrated:
+    configs = list(
+        itertools.product(
+            [768],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [768 * 8],
+            [2048],
+            [128],
+            [48],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+            ],
+        )
+    )
+else:
+    configs = list(
+        itertools.product(
+            [1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+            [2048],
+            [128],
+            [8, 16, 32, 48],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+            ],
+        )
+    )


 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
+        x_names=["num_tokens", "hidden_dim", "group_size", "num_ranks", "dst_dtype", "flags"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
-        line_names=["Triton", "SGL Kernel"],
+        # Triton has multi kernels and we only report the time for the core one
+        line_names=["Triton (Inaccurate)", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
         ylabel="us",
         plot_name="per-token-group-quant-8bit-performance",
         args={},
     )
 )
-def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
-    if flags["scale_ue8m0"] and group_size != 128:
-        return
-
-    device = torch.device("cuda")
-
-    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
+def benchmark(
+    num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider
+):
+    print(
+        f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}"
+    )
+
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )

     fn, kernel_names = {
-        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
+        "triton": (
+            triton_per_token_group_quant_8bit,
+            "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
+        ),
         "sglang": (
             sglang_per_token_group_quant_8bit,
             "per_token_group_quant_8bit_kernel",
         ),
     }[provider]

-    bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
+    bench_fn = lambda: fn(
+        x=x,
+        masked_m=masked_m,
+        group_size=group_size,
+        dst_dtype=dst_dtype,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
+    )

-    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
+    time_s = bench_kineto(
+        bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30
+    )
     return time_s * 1e6
...
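The fuse_silu_and_mul flag benchmarked above means the kernel applies the SiLU-and-mul gating before quantizing, so the input carries 2 * hidden_dim features per token. A hedged reference of the unfused semantics (illustrative only, names are mine; the real kernel fuses this with the quantization step):

import torch

def silu_and_mul_reference(x2h: torch.Tensor) -> torch.Tensor:
    # First half is the gate, second half is the up-projection output.
    hidden_dim = x2h.shape[-1] // 2
    gate, up = x2h[..., :hidden_dim], x2h[..., hidden_dim:]
    return torch.nn.functional.silu(gate) * up

x = torch.randn(16, 2 * 128, dtype=torch.bfloat16)
print(silu_and_mul_reference(x).shape)  # torch.Size([16, 128]), then quantized per group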
sgl-kernel/csrc/common_extension.cc  (+3, -8)

...
@@ -121,14 +121,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

   m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-
-  m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+      "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
+  m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);

   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
...
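A quick hedged check that the registration above is what the Python wrapper later dispatches to (assumes a CUDA build of sgl_kernel is installed; this only introspects the op, it does not add any new API):

import torch
import sgl_kernel  # noqa: F401 -- importing registers the sgl_kernel library fragment

# The wrapper in sgl-kernel/python/sgl_kernel/gemm.py calls exactly this overload.
op = torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default
print(op)  # OpOverload bound to the schema defined in common_extension.cc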
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu  (+447, -177)
(diff collapsed on the original page; not shown)
sgl-kernel/include/sgl_kernel_ops.h  (+6, -12)

...
@@ -207,23 +207,17 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Dtype& out_dtype);

 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);

-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     at::Tensor input,
     at::Tensor output_q,
     at::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max,
-    bool scale_ue8m0);
-
-void sgl_per_token_group_quant_int8(
-    at::Tensor input,
-    at::Tensor output_q,
-    at::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double int8_min,
-    double int8_max);
+    double min_8bit,
+    double max_8bit,
+    bool scale_ue8m0,
+    bool fuse_silu_and_mul,
+    const std::optional<torch::Tensor>& masked_m);

 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);

 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);

 void bmm_fp8(
...
sgl-kernel/python/sgl_kernel/__init__.py  (+1, -2)

...
@@ -55,8 +55,7 @@ from sgl_kernel.gemm import (
     scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
-    sgl_per_token_group_quant_fp8,
-    sgl_per_token_group_quant_int8,
+    sgl_per_token_group_quant_8bit,
     sgl_per_token_quant_fp8,
     shuffle_rows,
     silu_and_mul_scaled_fp4_grouped_quant,
...
sgl-kernel/python/sgl_kernel/gemm.py  (+15, -18)

...
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output


-def sgl_per_token_group_quant_fp8(
+def sgl_per_token_group_quant_8bit(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
...
@@ -106,24 +106,21 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-    )
-
-
-def sgl_per_token_group_quant_int8(
-    input: torch.Tensor,
-    output_q: torch.Tensor,
-    output_s: torch.Tensor,
-    group_size: int,
-    eps: float,
-    int8_min: float,
-    int8_max: float,
-) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
-        input, output_q, output_s, group_size, eps, int8_min, int8_max
-    )
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
+        input,
+        output_q,
+        output_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        scale_ue8m0,
+        fuse_silu_and_mul,
+        masked_m,
+    )
...
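A hedged usage sketch of the renamed wrapper above. It assumes a CUDA build of sgl_kernel is installed and uses the simple row-major scale layout; the output shapes here are my illustrative choice, not taken from the kernel's documentation, and the new optional arguments are left at their defaults so the call behaves like the old FP8-only entry point.

import torch
from sgl_kernel import sgl_per_token_group_quant_8bit

num_tokens, hidden_dim, group_size = 16, 1024, 128
fp8_info = torch.finfo(torch.float8_e4m3fn)

x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
x_q = torch.empty(num_tokens, hidden_dim, device="cuda", dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden_dim // group_size, device="cuda", dtype=torch.float32)

# scale_ue8m0, fuse_silu_and_mul, and masked_m default to the old behavior.
sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, 1e-10, fp8_info.min, fp8_info.max)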
sgl-kernel/python/sgl_kernel/test_utils.py  (new file, +125)

import torch


def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
    device = torch.device("cuda")
    dtype = torch.bfloat16

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    if flags["fuse_silu_and_mul"]:
        effective_hidden_dim = hidden_dim * 2
    else:
        effective_hidden_dim = hidden_dim
    del hidden_dim

    if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
        num_max_dispatch_tokens_per_rank = 768
        num_global_experts = 288
        num_local_experts, remainder = divmod(num_global_experts, num_ranks)
        assert remainder == 0

        # mimic DeepEP low_latency_dispatch output
        x = torch.randn(
            num_local_experts,
            num_max_dispatch_tokens_per_rank * num_ranks,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )

        if masked_layout_mode == "balanced":
            masked_m = _compute_balanced_split(num_tokens, num_local_experts)
        elif masked_layout_mode == "imbalanced":
            masked_m = _compute_imbalanced_split(num_tokens, num_local_experts, gen_cpu=gen_cpu)
        elif masked_layout_mode == "extreme":
            masked_m = torch.tensor([num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int)
        else:
            raise NotImplementedError

        print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
        masked_m = masked_m.to(device)

        return x, masked_m
    else:
        x = torch.randn(
            num_tokens,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )
        x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
        return x, None


def _compute_balanced_split(total: int, arr_len: int):
    base = total // arr_len
    remainder = total % arr_len
    ans = [base + 1 if i < remainder else base for i in range(arr_len)]
    assert sum(ans) == total
    return torch.tensor(ans, dtype=torch.int)


def _compute_imbalanced_split(total: int, arr_len: int, gen_cpu, dtype=torch.int) -> list[int]:
    # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
    noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
    noise = noise_raw / noise_raw.sum()

    ans = (noise * total).round().to(dtype)

    diff = total - ans.sum().item()
    while diff != 0:
        idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
        if diff > 0:
            ans[idx] += 1
            diff -= 1
        elif diff < 0 and ans[idx] > 0:
            ans[idx] -= 1
            diff += 1

    assert sum(ans) == total
    return ans


def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
    assert (a.shape == b.shape) and (
        a.dtype == b.dtype
    ), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"

    numel = a.numel()

    if a.dtype == torch.float8_e4m3fn:
        a_u8 = a.view(torch.uint8)
        b_u8 = b.view(torch.uint8)
        diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
        count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
        count_tiny_diff = (diff_u8 == 1).sum().item()
        count_large_diff = (diff_u8 >= 2).sum().item()
    elif a.dtype == torch.int8:
        diff = (a.to(torch.int16) - b.to(torch.int16)).abs()
        count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
        count_tiny_diff = (diff == 1).sum().item()
        count_large_diff = (diff >= 2).sum().item()
    else:
        raise NotImplementedError

    assert (
        (count_diff_sign == 0)
        and (count_large_diff == 0)
        and (
            (count_tiny_diff / numel < 0.005)
            or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
        )
    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
sgl-kernel/tests/test_per_token_group_quant_8bit.py  (+161, -58)

 import itertools
+import os
+import time
+from pathlib import Path

 import pytest
 import torch
+from sgl_kernel.test_utils import (
+    assert_all_close_or_tiny_diff,
+    create_per_token_group_quant_test_data,
+)

+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
 )
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
-from sglang.srt.layers.quantization.utils import assert_fp8_all_close
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-@pytest.mark.parametrize(
-    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
-    list(
-        itertools.product(
-            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
-            [256, 512, 1024, 2048, 4096],  # hidden_dim
-            [8, 16, 32, 64, 128],  # group_size
-            [fp8_type_, torch.int8],  # dtype
-            [
-                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
-            ],
-        )
-    ),
-)
+configs = list(
+    itertools.product(
+        [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192],  # num_tokens
+        [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384],  # hidden_dim
+        [16, 32, 64, 128],  # group_size
+        # TODO test int8
+        [None],  # num_ranks
+        [fp8_type_],  # dtype
+        [
+            dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
+        ],
+    )
+) + list(
+    itertools.product(
+        [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+        # TODO support more
+        [2048],
+        [128],
+        [8, 16, 32, 48],
+        [fp8_type_],
+        [
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+        ],
+    )
+)


+@pytest.mark.parametrize(
+    "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs
+)
 def test_per_token_group_quant_with_column_major(
     num_tokens,
     hidden_dim,
     group_size,
+    num_ranks,
     dst_dtype,
     flags,
 ):
-    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
-        pytest.skip()
+    print(
+        f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}"
+    )
+
+    arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
+    if flags["scale_ue8m0"] and (arch_major <= 9):
+        pytest.skip("Only Blackwell need ue8m0 fusion")
         return
+
+    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
+        pytest.skip("scale_ue8m0 only supported on Blackwell")
+    if (flags["scale_ue8m0"] and (group_size != 128)) or (
+        (dst_dtype == torch.int8) and flags["column_major_scales"]
+    ):
+        pytest.skip()
         return

-    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )
+    # print("hack data!!!")
+    # x = torch.full_like(x, fill_value=100)

     execute_kwargs = dict(
         x=x,
+        masked_m=masked_m,
         group_size=group_size,
         eps=1e-10,
         dst_dtype=dst_dtype,
-        **flags,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
     )

-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
-
-    # torch.set_printoptions(profile="full")
-    # print(f"{x_q_triton=}")
-    # print(f"{x_s_triton=}")
-    # print(f"{x_q_sglang=}")
-    # print(f"{x_s_sglang=}")
-    # torch.set_printoptions(profile="default")
-
-    assert_fp8_all_close(x_q_triton, x_q_sglang)
-    torch.testing.assert_close(
-        x_s_triton.contiguous(),
-        x_s_sglang.contiguous(),
+    def _postprocess(x_q, x_s):
+        if masked_m is not None:
+            print(f"Mask tokens after {masked_m} to be zero")
+            for i in range(len(masked_m)):
+                x_q[i, masked_m[i]:, :] = 0
+                x_s[i, masked_m[i]:, :] = 0
+        return x_q, x_s
+
+    x_q_triton, x_s_triton = _postprocess(
+        *triton_per_token_group_quant_8bit(**execute_kwargs)
+    )
+    x_q_sglang, x_s_sglang = _postprocess(
+        *sglang_per_token_group_quant_8bit(**execute_kwargs)
+    )
+
+    try:
+        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
+        torch.testing.assert_close(
+            x_s_triton.contiguous(),
+            x_s_sglang.contiguous(),
...
@@ -91,6 +165,35 @@ def test_per_token_group_quant_with_column_major(
-        atol=1e-5,
-        msg=lambda message: message + f"{x_s_triton=} {x_s_sglang=}",
-    )
+            atol=1e-5,
+            msg=lambda message: message + f"{x_s_triton=} {x_s_sglang=}",
+        )
+    except AssertionError:
+        # torch.set_printoptions(profile="full")
+        print(
+            f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}"
+        )
+        print(f"{x=}")
+        print(f"{masked_m=}")
+        print(f"{x_q_triton=}")
+        print(f"{x_s_triton=}")
+        print(f"{x_q_sglang=}")
+        print(f"{x_s_sglang=}")
+        # torch.set_printoptions(profile="default")
+        # if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
+        #     import matplotlib.pyplot as plt
+        #
+        #     base_stem = time.time()
+        #     for name, value in [
+        #         ("x_q", x_q_triton != x_q_sglang),
+        #         ("x_s", x_s_triton != x_s_sglang),
+        #     ]:
+        #         value = value.reshape((-1, value.shape[-1]))
+        #         plt.figure(figsize=(20, 20))
+        #         plt.imshow((value * 1.0).cpu().numpy())
+        #         p = Path(d) / f"{base_stem}_{name}.png"
+        #         print(f"Write diff to {p}", flush=True)
+        #         plt.savefig(p)
+        raise


 if __name__ == "__main__":
...