sglang / Commits / b1b3f0b3

Commit b1b3f0b3 (unverified), authored Aug 23, 2025 by fzyzcjy; committed by GitHub on Aug 23, 2025.
Partially unify triton per token group quant kernels (#9485)
Parent: 34e5e11f

Showing 1 changed file with 161 additions and 57 deletions:
python/sglang/srt/layers/quantization/fp8_kernel.py (+161, -57)
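For orientation (not part of the commit): the kernels below implement per-token-group quantization, where each `group_size`-wide slice of a row gets one scale `s = max(absmax, eps) / q_max` and is stored as `clamp(x / s, q_min, q_max)` in an 8-bit dtype. A minimal PyTorch sketch of that math, assuming a torch build with float8 dtypes; `reference_per_token_group_quant` is a hypothetical name, not an sglang API:

```python
import torch

def reference_per_token_group_quant(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.float8_e4m3fn,
):
    # One scale per contiguous group of `group_size` elements along the last dim.
    assert x.shape[-1] % group_size == 0
    groups = x.reshape(-1, group_size).to(torch.float32)
    q_max = torch.finfo(dtype).max
    # absmax per group, floored at eps to avoid dividing by zero.
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / q_max
    q = (groups / scale).clamp(-q_max, q_max).to(dtype)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], x.shape[-1] // group_size)
```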
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -113,7 +113,7 @@ if supports_custom_op():
 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -125,8 +125,8 @@ def _per_token_group_quant_fp8(
     # Avoid dividing by zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -147,16 +147,16 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)

     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     y_s_inv = 1.0 / y_s
-    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y * y_s_inv, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


 @triton.jit
-def _per_token_group_quant_fp8_colmajor(
+def _per_token_group_quant_8bit_colmajor(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -169,8 +169,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Avoid dividing by zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
     SCALE_UE8M0: tl.constexpr,
@@ -197,19 +197,20 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)

     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     if SCALE_UE8M0:
         y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y / y_s, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


-def per_token_group_quant_fp8(
+def _per_token_group_quant_8bit_raw(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_dtype,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
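The `SCALE_UE8M0` branch above rounds each group scale up to the next power of two, so the scale can later be stored as a bare 8-bit exponent (UE8M0). Rounding up can only shrink quantized magnitudes, so values stay in range. A quick illustration, not from the commit, assuming Python 3.11+ for `math.exp2`:

```python
import math

def round_scale_up_to_pow2(s: float) -> float:
    # Mirrors tl.exp2(tl.ceil(tl.log2(tl.abs(s)))) in the kernel.
    return math.exp2(math.ceil(math.log2(abs(s))))

assert round_scale_up_to_pow2(0.3) == 0.5   # log2(0.3) ~ -1.74, ceil -> -1
assert round_scale_up_to_pow2(1.0) == 1.0   # exact powers of two are kept
assert round_scale_up_to_pow2(5.0) == 8.0   # rounded up, never down
```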
@@ -223,6 +224,7 @@ def per_token_group_quant_fp8(
         x: The input tensor with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum value to avoid dividing by zero.
+        dtype: The dtype of the output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -232,7 +234,21 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` must be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    if _is_hip:
+        if dtype == torch.int8:
+            bit8_max = 127.0
+        else:
+            bit8_max = 224.0
+        bit8_min = -bit8_max  # TODO incorrect for int8
+    else:
+        if dtype == torch.int8:
+            info = torch.iinfo(dtype)
+        else:
+            info = torch.finfo(dtype)
+        bit8_max = info.max
+        bit8_min = info.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
         device=x.device,
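The new `_is_hip` branch replaces hardcoded fp8 constants with ranges derived from the destination dtype (`iinfo` for int8, `finfo` for float8), which is what lets one kernel serve both. For reference, the values these calls return on a stock PyTorch build with float8 support:

```python
import torch

print(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)
# -128 127
print(torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max)
# -448.0 448.0
```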
@@ -250,7 +266,7 @@ def per_token_group_quant_fp8(
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
     if column_major_scales:
-        _per_token_group_quant_fp8_colmajor[(M,)](
+        _per_token_group_quant_8bit_colmajor[(M,)](
             x,
             x_q,
             x_s,
@@ -258,8 +274,8 @@ def per_token_group_quant_fp8(
             x.shape[1],
             x_s.stride(1),
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -267,15 +283,15 @@ def per_token_group_quant_fp8(
         )
     else:
         assert not scale_ue8m0
-        _per_token_group_quant_fp8[(M,)](
+        _per_token_group_quant_8bit[(M,)](
             x,
             x_q,
             x_s,
             group_size,
             N,
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -297,6 +313,117 @@ def per_token_group_quant_fp8(
     return x_q, x_s


+# backward compatibility
+per_token_group_quant_fp8 = _per_token_group_quant_8bit_raw
+
+
+def _per_token_group_quant_8bit_fuse_silu_and_mul(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
+    masked_m: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Another way to implement (can be used in e.g. comparison tests)
+    # from sgl_kernel import silu_and_mul
+    # x_after_silu_and_mul = silu_and_mul(x)
+    # return per_token_group_quant_fp8(
+    #     x_after_silu_and_mul,
+    #     group_size=group_size,
+    #     eps=eps,
+    #     column_major_scales=column_major_scales,
+    #     scale_tma_aligned=scale_tma_aligned,
+    #     scale_ue8m0=scale_ue8m0,
+    # )
+
+    from deep_gemm.utils.layout import transform_sf_into_required_layout
+    from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
+
+    assert column_major_scales
+    assert scale_tma_aligned
+    assert scale_ue8m0
+
+    needs_unsqueeze = x.dim() == 2
+    if needs_unsqueeze:
+        num_tokens, _ = x.shape
+        x = x.unsqueeze(0)
+
+        assert masked_m is None
+        masked_m = torch.tensor([num_tokens], device=x.device, dtype=torch.int32)
+
+    # Use `zeros` for easier testing
+    output = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2),
+        device=x.device,
+        dtype=dst_dtype,
+    )
+    # Use `zeros` for easier testing
+    output_scale_for_kernel = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2 // group_size),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    silu_and_mul_masked_post_quant_fwd(
+        input=x,
+        output=output,
+        output_scale=output_scale_for_kernel,
+        quant_group_size=group_size,
+        masked_m=masked_m,
+        scale_ue8m0=scale_ue8m0,
+    )
+
+    assert group_size == 128
+    output_scale = transform_sf_into_required_layout(
+        output_scale_for_kernel,
+        num_groups=output.shape[0],
+        mn=output.shape[-2],
+        k=output.shape[-1],
+        recipe=(1, group_size, group_size),
+        is_sfa=True,
+    )
+
+    if needs_unsqueeze:
+        output = output.squeeze(0)
+        output_scale = output_scale.squeeze(0)
+
+    return output, output_scale
+
+
+def per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if fuse_silu_and_mul:
+        return _per_token_group_quant_8bit_fuse_silu_and_mul(
+            x=x,
+            group_size=group_size,
+            dst_dtype=dst_dtype,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            masked_m=masked_m,
+        )
+    else:
+        return _per_token_group_quant_8bit_raw(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            dtype=dst_dtype,
+        )
+
+
 def create_per_token_group_quant_fp8_output_scale(
     x_shape,
     device,
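Taken together, `per_token_group_quant_8bit` becomes the single entry point: the old `per_token_group_quant_fp8` name is kept as an alias of the raw path, and `fuse_silu_and_mul=True` routes through the fused kernel. A hypothetical call sketch (shapes illustrative; the fused path additionally requires a CUDA device and the `deep_gemm` dependency):

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_8bit

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)

# Raw path: quantize x as-is, one scale per 128-wide group.
x_q, x_s = per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.float8_e4m3fn)

# Fused path: x is interpreted as [gate | up]; the output width is halved by
# silu(gate) * up before quantization. Per the asserts above, this path
# requires column-major, TMA-aligned, UE8M0 scales.
y_q, y_s = per_token_group_quant_8bit(
    x,
    group_size=128,
    dst_dtype=torch.float8_e4m3fn,
    column_major_scales=True,
    scale_tma_aligned=True,
    scale_ue8m0=True,
    fuse_silu_and_mul=True,
)
```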
@@ -307,16 +434,16 @@ def create_per_token_group_quant_fp8_output_scale(
 ):
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x_shape
+        *x_batch, x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        return torch.zeros(
-            (aligned_k // 4, aligned_mn),
+        return torch.empty(
+            (*x_batch, aligned_k // 4, aligned_mn),
             device=device,
             dtype=torch.int,
-        ).transpose(0, 1)[:x_s_mn, :]
+        ).transpose(-1, -2)[..., :x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
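The generalized UE8M0 branch now carries an optional batch dimension (`*x_batch`) through the scale buffer. The shape arithmetic, worked through for illustrative numbers (the `align` helper mirrors the one this file already uses; not part of the commit):

```python
def align(n: int, m: int) -> int:
    # Round n up to a multiple of m.
    return (n + m - 1) // m * m

x_q_mn, x_q_k = 9, 512                 # e.g. 9 tokens, hidden size 512
x_s_mn, x_s_k = x_q_mn, x_q_k // 128   # one scale per 128-wide group -> 9 x 4
aligned_mn, aligned_k = align(x_s_mn, 4), align(x_s_k, 4)  # 12, 4
# Backing buffer: (aligned_k // 4, aligned_mn) int32 -> (1, 12); four UE8M0
# exponent bytes pack into each int32. After .transpose(-1, -2)[..., :x_s_mn, :]
# the view handed back is (9, 1).
```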
@@ -341,39 +468,6 @@ def create_per_token_group_quant_fp8_output_scale(
     )


-# TODO maybe unify int8 and fp8 code later
-def per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    dst_dtype: torch.dtype,
-    eps: float = 1e-10,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-    scale_ue8m0: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
-
-    if dst_dtype == torch.int8:
-        assert not column_major_scales
-        assert not scale_tma_aligned
-        assert not scale_ue8m0
-        return per_token_group_quant_int8(
-            x=x,
-            group_size=group_size,
-            eps=eps,
-            dtype=dst_dtype,
-        )
-
-    return per_token_group_quant_fp8(
-        x=x,
-        group_size=group_size,
-        eps=eps,
-        column_major_scales=column_major_scales,
-        scale_tma_aligned=scale_tma_aligned,
-        scale_ue8m0=scale_ue8m0,
-    )
-
-
 def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
@@ -381,15 +475,19 @@ def sglang_per_token_group_quant_fp8(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` must be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))
+    x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
-        x_shape=x.shape,
+        x_shape=out_shape,
         device=x.device,
         group_size=group_size,
         column_major_scales=column_major_scales,
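The same `out_shape` trick appears here: when the SiLU-and-mul is fused, the quantized output is half as wide as the input, and the scale buffer is sized from `out_shape` rather than `x.shape`. A one-line check with hypothetical sizes:

```python
x_shape, fuse_silu_and_mul = (16, 1024), True   # [tokens, 2 * intermediate]
out_shape = (*x_shape[:-1], x_shape[-1] // (2 if fuse_silu_and_mul else 1))
assert out_shape == (16, 512)  # gate/up halves combine into one output block
```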
@@ -414,6 +512,8 @@ def sglang_per_token_group_quant_8bit(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     from sglang.srt.layers.quantization.int8_kernel import (
         sglang_per_token_group_quant_int8,
@@ -422,6 +522,8 @@ def sglang_per_token_group_quant_8bit(
     if dst_dtype == torch.int8:
         assert not column_major_scales
         assert not scale_tma_aligned
+        assert not fuse_silu_and_mul
+        assert masked_m is None
         return sglang_per_token_group_quant_int8(
             x=x,
             group_size=group_size,
@@ -436,6 +538,8 @@ def sglang_per_token_group_quant_8bit(
         column_major_scales=column_major_scales,
         scale_tma_aligned=scale_tma_aligned,
         scale_ue8m0=scale_ue8m0,
+        fuse_silu_and_mul=fuse_silu_and_mul,
+        masked_m=masked_m,
     )