change / sglang / Commits / e85cb1ce

Commit e85cb1ce (unverified), authored Aug 21, 2025 by fzyzcjy, committed by GitHub on Aug 21, 2025

Fix quant kernel test errors and benchmark wrong output speeds (#7604)
Parent: 55d336cb

Showing 4 changed files with 205 additions and 463 deletions (+205 -463)
python/sglang/srt/layers/quantization/fp8_kernel.py         +67   -0
python/sglang/srt/layers/quantization/utils.py              +21   -0
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py    +59   -183
sgl-kernel/tests/test_per_token_group_quant_8bit.py         +58   -280
python/sglang/srt/layers/quantization/fp8_kernel.py (view file @ e85cb1ce)

...

@@ -341,6 +341,39 @@ def create_per_token_group_quant_fp8_output_scale(
     )


+# TODO maybe unify int8 and fp8 code later
+def per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
+
+    if dst_dtype == torch.int8:
+        assert not column_major_scales
+        assert not scale_tma_aligned
+        assert not scale_ue8m0
+        return per_token_group_quant_int8(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            dtype=dst_dtype,
+        )
+
+    return per_token_group_quant_fp8(
+        x=x,
+        group_size=group_size,
+        eps=eps,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
+
+
 def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,

...

@@ -372,6 +405,40 @@ def sglang_per_token_group_quant_fp8(
     return x_q, x_s


+# TODO maybe unify int8 and fp8 code later
+def sglang_per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+):
+    from sglang.srt.layers.quantization.int8_kernel import (
+        sglang_per_token_group_quant_int8,
+    )
+
+    if dst_dtype == torch.int8:
+        assert not column_major_scales
+        assert not scale_tma_aligned
+        return sglang_per_token_group_quant_int8(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            dtype=dst_dtype,
+        )
+
+    return sglang_per_token_group_quant_fp8(
+        x=x,
+        group_size=group_size,
+        eps=eps,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
+
+
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
     dtype: torch.dtype = fp8_dtype,

...
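For reference (not part of this commit), a minimal usage sketch of the new dispatcher, assuming an sglang build that ships these kernels and a CUDA device; the sizes are illustrative. The int8 branch rejects the scale-layout flags, while the fp8 branch forwards them to per_token_group_quant_fp8.

import torch

from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_8bit

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)

# fp8 path: per-group scales, optionally in a column-major layout.
x_q_fp8, x_s_fp8 = per_token_group_quant_8bit(
    x, group_size=128, dst_dtype=torch.float8_e4m3fn, column_major_scales=True
)

# int8 path: the scale-layout flags must stay at their defaults (asserted off).
x_q_i8, x_s_i8 = per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.int8)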
python/sglang/srt/layers/quantization/utils.py (view file @ e85cb1ce)

...

@@ -176,6 +176,27 @@ def replace_parameter(
     mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))


+def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
+    assert a.shape == b.shape
+    assert a.dtype == b.dtype == torch.float8_e4m3fn
+
+    a_u8 = a.view(torch.uint8)
+    b_u8 = b.view(torch.uint8)
+    diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
+
+    numel = a.numel()
+    count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
+    count_tiny_diff = (diff_u8 >= 1).sum().item()
+    count_large_diff = (diff_u8 >= 2).sum().item()
+
+    assert (
+        (count_diff_sign == 0)
+        and (count_tiny_diff / numel < 0.005)
+        and (count_large_diff == 0)
+    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
+
+
 # Match dynamic rules with module name (prefix) and override quantize
 # config if module (prefix) matches a rule
 def override_config(config: QuantizationConfig, prefix: str):

...
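A brief usage sketch of the new helper (again assuming an sglang install with CUDA; the tensors are illustrative): it counts bit-level differences between the two tensors' uint8 views and only passes when no element differs by two or more fp8 codes and fewer than 0.5% of elements differ at all.

import torch

from sglang.srt.layers.quantization.utils import assert_fp8_all_close

a = torch.randn(4, 128, device="cuda").to(torch.float8_e4m3fn)
b = a.clone()

assert_fp8_all_close(a, b)  # identical bit patterns, so every diff counter is zero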
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py (view file @ e85cb1ce)
 import itertools
-from typing import Tuple
+import time
+from functools import partial
+from pathlib import Path

 import torch
 import triton
-import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
+from sglang.srt.bench_utils import bench_kineto
+from sglang.srt.layers.quantization.fp8_kernel import (
+    create_per_token_group_quant_fp8_output_scale,
+)
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
+)
+from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-@triton.jit
-def _per_token_group_quant_8bit(
-    # Pointers to inputs and output
-    y_ptr,
-    y_q_ptr,
-    y_s_ptr,
-    # Stride of input
-    y_stride,
-    # Columns of input
-    N,
-    # Avoid to divide zero
-    eps,
-    # Information for 8bit data type (int8 or fp8_type_)
-    max_8bit,
-    min_8bit,
-    # Meta-parameters
-    BLOCK: tl.constexpr,
-):
-    """A Triton-accelerated function to perform per-token-group quantization on a
-    tensor.
-
-    This function converts the tensor values into 8bit values.
-    """
-    # Map the program id to the row of X and Y it should compute.
-    g_id = tl.program_id(0)
-    y_ptr += g_id * y_stride
-    y_q_ptr += g_id * y_stride
-    y_s_ptr += g_id
-
-    cols = tl.arange(0, BLOCK)  # N <= BLOCK
-    mask = cols < N
-
-    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
-    # Quant
-    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / max_8bit
-    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
-
-    tl.store(y_q_ptr + cols, y_q, mask=mask)
-    tl.store(y_s_ptr, y_s)
+num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
+hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
-def triton_per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    dst_dtype: torch.dtype,
-    eps: float = 1e-10,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Function to perform per-token-group quantization on an input tensor `x`.
-
-    It converts the tensor values into signed float8 values and returns the
-    quantized tensor along with the scaling factor used for quantization.
-
-    Args:
-        x: The input tenosr with ndim >= 2.
-        group_size: The group size used for quantization.
-        eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
-    """
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    if dst_dtype == torch.int8:
-        iinfo = torch.iinfo(dst_dtype)
-        max_8bit = iinfo.max
-        min_8bit = iinfo.min
-    else:
-        finfo = torch.finfo(dst_dtype)
-        max_8bit = finfo.max
-        min_8bit = finfo.min
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
-    M = x.numel() // group_size
-    N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
-
-    BLOCK = triton.next_power_of_2(N)
-    # heuristics for number of warps
-    num_warps = min(max(BLOCK // 256, 1), 8)
-    num_stages = 1
-    _per_token_group_quant_8bit[(M,)](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        max_8bit,
-        min_8bit,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
-
-    return x_q, x_s
-def sglang_per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    dst_dtype: torch.dtype,
-    eps: float = 1e-10,
-):
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
-
-    if dst_dtype == torch.int8:
-        iinfo = torch.iinfo(dst_dtype)
-        int8_max = iinfo.max
-        int8_min = iinfo.min
-        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
-    else:
-        f8_info = torch.finfo(dst_dtype)
-        fp8_max = f8_info.max
-        fp8_min = f8_info.min
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
-
-    return x_q, x_s
-def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
-    device = torch.device("cuda")
-    hidden_dim = 7168
-
-    x = torch.randn(
-        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
-    )
-
-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
-        x.clone(), group_size, dst_dtype
-    )
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
-        x.clone(), group_size, dst_dtype
-    )
-
-    if torch.allclose(
-        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
-    ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
-        print(f"✅ {dst_dtype} implementations match")
-    else:
-        print("❌ Implementations differ")
-batch_size_range = [1, 2, 4, 8, 16, 32, 64]
-seq_len_range = [64, 128, 256, 512, 1024, 2048]
 group_size_range = [128]  # For DeepSeek V3/R1
-dst_dtype_range = [torch.int8, fp8_type_]
+# TODO test int8
+dst_dtype_range = [fp8_type_]
+flags_range = [
+    dict(
+        column_major_scales=False,
+        scale_tma_aligned=False,
+        scale_ue8m0=False,
+    ),
+    dict(
+        column_major_scales=True,
+        scale_tma_aligned=False,
+        scale_ue8m0=False,
+    ),
+    dict(
+        column_major_scales=True,
+        scale_tma_aligned=True,
+        scale_ue8m0=False,
+    ),
+    dict(
+        column_major_scales=True,
+        scale_tma_aligned=True,
+        scale_ue8m0=True,
+    ),
+]
 configs = list(
     itertools.product(
-        batch_size_range, seq_len_range, group_size_range, dst_dtype_range
+        num_tokens_range,
+        hidden_dim_range,
+        group_size_range,
+        dst_dtype_range,
+        flags_range,
     )
 )


 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len", "group_size", "dst_dtype"],
+        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],

...
@@ -194,29 +73,26 @@ configs = list(
         args={},
     )
 )
-def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
-    device = torch.device("cuda")
-    hidden_dim = 7168
-
-    x = torch.randn(
-        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
-    )
-
-    quantiles = [0.5, 0.2, 0.8]
-
-    if provider == "triton":
-        fn = lambda: triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
-    elif provider == "sglang":
-        fn = lambda: sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
-
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
-
-    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
+    if flags["scale_ue8m0"] and group_size != 128:
+        return
+
+    device = torch.device("cuda")
+
+    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
+
+    fn, kernel_names = {
+        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
+        "sglang": (
+            sglang_per_token_group_quant_8bit,
+            "per_token_group_quant_8bit_kernel",
+        ),
+    }[provider]
+
+    bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
+
+    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
+    return time_s * 1e6
 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8)
-    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_)
-
     benchmark.run(print_data=True)
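For orientation, a hedged pure-PyTorch sketch of what both benchmark providers compute (function name and shapes are illustrative, not part of the benchmark): for every contiguous group of group_size values along the last dimension, the scale is max(|x|, eps) / dtype_max and the payload is the scaled values clamped and cast to the 8-bit target.

import torch

def reference_per_token_group_quant(x, group_size, dst_dtype, eps=1e-10):
    # Illustrative reference only; the real kernels live in sgl-kernel and fp8_kernel.py.
    assert x.shape[-1] % group_size == 0
    g = x.float().view(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    if dst_dtype == torch.int8:
        q_max = float(torch.iinfo(dst_dtype).max)
        q_min = float(torch.iinfo(dst_dtype).min)
    else:
        q_max = torch.finfo(dst_dtype).max
        q_min = torch.finfo(dst_dtype).min
    # One scale per group: absmax / q_max, floored at eps to avoid dividing by zero.
    scale = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / q_max
    q = (g / scale).clamp(q_min, q_max).to(dst_dtype).view_as(x)
    return q, scale.squeeze(-1)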
sgl-kernel/tests/test_per_token_group_quant_8bit.py (view file @ e85cb1ce)
 import itertools
-from typing import Tuple

 import pytest
 import torch
-import triton
-import triton.language as tl
-from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
+from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
+)
+from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
+from sglang.srt.layers.quantization.utils import assert_fp8_all_close
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-@triton.jit
-def _per_token_group_quant_fp8(
-    # Pointers to inputs and output
-    y_ptr,
-    y_q_ptr,
-    y_s_ptr,
-    # Stride of input
-    y_stride,
-    # Columns of input
-    N,
-    # Avoid to divide zero
-    eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
-    # Meta-parameters
-    BLOCK: tl.constexpr,
-):
-    """A Triton-accelerated function to perform per-token-group quantization on a
-    tensor.
-
-    This function converts the tensor values into float8 values.
-    """
-    # Map the program id to the row of X and Y it should compute.
-    g_id = tl.program_id(0)
-    y_ptr += g_id * y_stride
-    y_q_ptr += g_id * y_stride
-    y_s_ptr += g_id
-
-    cols = tl.arange(0, BLOCK)  # N <= BLOCK
-    mask = cols < N
-
-    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
-    # Quant
-    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_s_inv = 1.0 / y_s
-    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
-
-    tl.store(y_q_ptr + cols, y_q, mask=mask)
-    tl.store(y_s_ptr, y_s)
-@triton.jit
-def _per_token_group_quant_fp8_colmajor(
-    # Pointers to inputs and output
-    y_ptr,
-    y_q_ptr,
-    y_s_ptr,
-    group_size,
-    # Num columns of y
-    y_num_columns,
-    # Stride from one column to the next of y_s
-    y_s_col_stride,
-    # Avoid to divide zero
-    eps,
-    # Information for float8
-    fp8_min,
-    fp8_max,
-    # Meta-parameters
-    BLOCK: tl.constexpr,
-):
-    """A Triton-accelerated function to perform per-token-group
-    quantization on a tensor.
-
-    This function converts the tensor values into float8 values.
-    """
-    # Map the program id to the row of X and Y it should compute.
-    g_id = tl.program_id(0)
-    y_ptr += g_id * group_size
-    y_q_ptr += g_id * group_size
-
-    # Convert g_id the flattened block coordinate to 2D so we can index
-    # into the output y_scales matrix
-    blocks_per_row = y_num_columns // group_size
-    scale_col = g_id % blocks_per_row
-    scale_row = g_id // blocks_per_row
-    y_s_ptr += scale_col * y_s_col_stride + scale_row
-
-    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
-    mask = cols < group_size
-
-    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
-    # Quant
-    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
-
-    tl.store(y_q_ptr + cols, y_q, mask=mask)
-    tl.store(y_s_ptr, y_s)
-def triton_per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Function to perform per-token-group quantization on an input tensor `x`.
-
-    It converts the tensor values into signed float8 values and returns the
-    quantized tensor along with the scaling factor used for quantization.
-
-    Args:
-        x: The input tenosr with ndim >= 2.
-        group_size: The group size used for quantization.
-        eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
-    """
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    if dtype == torch.int8:
-        finfo = torch.iinfo(dtype)
-    else:
-        finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        if dtype == torch.int8:
-            fp8_max = 127.0
-        else:
-            fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
-
-    BLOCK = triton.next_power_of_2(N)
-    # heuristics for number of warps
-    num_warps = min(max(BLOCK // 256, 1), 8)
-    num_stages = 1
-    if column_major_scales:
-        _per_token_group_quant_fp8_colmajor[(M,)](
-            x,
-            x_q,
-            x_s,
-            group_size,
-            x.shape[1],
-            x_s.stride(1),
-            eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
-            BLOCK=BLOCK,
-            num_warps=num_warps,
-            num_stages=num_stages,
-        )
-    else:
-        _per_token_group_quant_fp8[(M,)](
-            x,
-            x_q,
-            x_s,
-            group_size,
-            N,
-            eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
-            BLOCK=BLOCK,
-            num_warps=num_warps,
-            num_stages=num_stages,
-        )
-
-    return x_q, x_s
-def sglang_per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-):
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
-
-    if dtype == torch.int8:
-        iinfo = torch.iinfo(dtype)
-        int8_max = iinfo.max
-        int8_min = iinfo.min
-        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
-    else:
-        f8_info = torch.finfo(dtype)
-        fp8_max = f8_info.max
-        fp8_min = f8_info.min
-        scale_ue8m0 = False  # TODO also test true
-        sgl_per_token_group_quant_fp8(
-            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-        )
-
-    return x_q, x_s
 @pytest.mark.parametrize(
-    "num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned",
+    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
     list(
         itertools.product(
             [127, 128, 512, 1024, 4096, 8192],  # num_tokens
             [256, 512, 1024, 2048, 4096],  # hidden_dim
             [8, 16, 32, 64, 128],  # group_size
-            [torch.int8, fp8_type_],  # dtype
-            [False, True],  # column_major_scales
-            [False, True],  # scale_tma_aligned
+            # TODO test int8
+            [fp8_type_],  # dtype
+            [
+                dict(
+                    column_major_scales=False,
+                    scale_tma_aligned=False,
+                    scale_ue8m0=False,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=False,
+                    scale_ue8m0=False,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=False,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                ),
+            ],
         )
     ),
 )
...

@@ -281,37 +54,42 @@ def test_per_token_group_quant_with_column_major(
     hidden_dim,
     group_size,
     dst_dtype,
-    column_major_scales,
-    scale_tma_aligned,
+    flags,
 ):
-    if not column_major_scales and scale_tma_aligned:
-        pytest.skip()
-        return
+    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
+        return
+
+    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
+        pytest.skip("scale_ue8m0 only supported on Blackwell")
+        return

-    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16)
+    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
-        x,
-        group_size,
-        eps=1e-10,
-        dtype=dst_dtype,
-        column_major_scales=column_major_scales,
-        scale_tma_aligned=scale_tma_aligned,
-    )
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
-        x,
-        group_size,
-        eps=1e-10,
-        dtype=dst_dtype,
-        column_major_scales=column_major_scales,
-        scale_tma_aligned=scale_tma_aligned,
-    )
+    execute_kwargs = dict(
+        x=x,
+        group_size=group_size,
+        eps=1e-10,
+        dst_dtype=dst_dtype,
+        **flags,
+    )
+
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
+
+    # torch.set_printoptions(profile="full")
+    # print(f"{x_q_triton=}")
+    # print(f"{x_s_triton=}")
+    # print(f"{x_q_sglang=}")
+    # print(f"{x_s_sglang=}")
+    # torch.set_printoptions(profile="default")

-    torch.testing.assert_close(
-        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
-    )
-    torch.testing.assert_close(
-        x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
-    )
+    assert_fp8_all_close(x_q_triton, x_q_sglang)
+    torch.testing.assert_close(
+        x_s_triton.contiguous(),
+        x_s_sglang.contiguous(),
+        rtol=1e-3,
+        atol=1e-5,
+        msg=lambda message: message + f"{x_s_triton=} {x_s_sglang=}",
+    )

...
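As a reading aid for the scale-layout flags the test sweeps, a hedged shape-only sketch (illustrative sizes, no kernels involved): row-major scales are [num_tokens, hidden_dim // group_size]; column_major_scales stores the transpose and exposes it through a permuted view; scale_tma_aligned additionally pads the token dimension up to a multiple of 4 floats before slicing the view back.

import torch

num_tokens, hidden_dim, group_size = 6, 256, 128
n_groups = hidden_dim // group_size

# Default layout: one fp32 scale per (token, group), row-major.
row_major = torch.empty(num_tokens, n_groups, dtype=torch.float32)

# column_major_scales=True: storage is [n_groups, num_tokens], viewed transposed.
col_major = torch.empty(n_groups, num_tokens, dtype=torch.float32).permute(-1, -2)

# scale_tma_aligned=True: pad tokens to a multiple of 4 (4 * sizeof(float)), then slice.
aligned = (num_tokens + 3) // 4 * 4
tma_aligned = torch.empty(n_groups, aligned, dtype=torch.float32).permute(-1, -2)[
    :num_tokens, :
]

print(row_major.shape, col_major.stride(), tma_aligned.stride())
# e.g. torch.Size([6, 2]) (1, 6) (1, 8)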