sglang commit b1b3f0b3 (unverified)
Authored Aug 23, 2025 by fzyzcjy; committed by GitHub on Aug 23, 2025
Partially unify triton per token group quant kernels (#9485)
Parent: 34e5e11f

Changes: 1 changed file with 161 additions and 57 deletions

python/sglang/srt/layers/quantization/fp8_kernel.py (+161, -57), view file @ b1b3f0b3
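In short, the commit renames the Triton kernels _per_token_group_quant_fp8 and _per_token_group_quant_fp8_colmajor to the dtype-agnostic _per_token_group_quant_8bit and _per_token_group_quant_8bit_colmajor, replaces their fp8_min/fp8_max arguments with generic bit8_min/bit8_max, renames per_token_group_quant_fp8 to _per_token_group_quant_8bit_raw (keeping the old name as an alias), and rewrites per_token_group_quant_8bit as a dispatcher that covers both int8 and fp8 destinations plus an optional fused silu-and-mul path.

For orientation, the following is a minimal pure-PyTorch sketch of the per-token-group absmax quantization these kernels implement. It is not code from the commit; the function name and the omission of the final cast to an 8-bit dtype are illustrative simplifications.

    import torch

    def reference_per_token_group_quant(x, group_size, q_min, q_max, eps=1e-10):
        # One absmax-derived scale per contiguous group of `group_size` elements
        # along the last dimension; the Triton kernels in the diff below do the
        # same, but write y_q directly in the destination 8-bit dtype.
        assert x.shape[-1] % group_size == 0
        groups = x.float().reshape(*x.shape[:-1], -1, group_size)
        absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
        scale = absmax / q_max
        q = (groups / scale).clamp(q_min, q_max)
        return q.reshape(x.shape), scale.squeeze(-1)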
@@ -113,7 +113,7 @@ if supports_custom_op():
 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -125,8 +125,8 @@ def _per_token_group_quant_fp8(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -147,16 +147,16 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     y_s_inv = 1.0 / y_s
-    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y * y_s_inv, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


 @triton.jit
-def _per_token_group_quant_fp8_colmajor(
+def _per_token_group_quant_8bit_colmajor(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -169,8 +169,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
     SCALE_UE8M0: tl.constexpr,
@@ -197,19 +197,20 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     if SCALE_UE8M0:
         y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y / y_s, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


-def per_token_group_quant_fp8(
+def _per_token_group_quant_8bit_raw(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_dtype,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
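The SCALE_UE8M0 branch above rounds each group scale up to the next power of two so that it can later be stored as a bare 8-bit exponent (UE8M0). The same math outside Triton, as a small sketch (the function name is made up here):

    import torch

    def round_scale_up_to_pow2(y_s):
        # Mirrors tl.exp2(tl.ceil(tl.log2(tl.abs(y_s)))): each scale becomes the
        # smallest power of two >= the original value, slightly coarser but
        # exactly representable by an exponent-only byte.
        return torch.exp2(torch.ceil(torch.log2(y_s.abs())))

    # Example: a scale of 0.0117 rounds up to 2**-6 = 0.015625.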
@@ -223,6 +224,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -232,7 +234,21 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    if _is_hip:
+        if dtype == torch.int8:
+            bit8_max = 127.0
+        else:
+            bit8_max = 224.0
+        bit8_min = -bit8_max  # TODO incorrect for int8
+    else:
+        if dtype == torch.int8:
+            info = torch.iinfo(dtype)
+        else:
+            info = torch.finfo(dtype)
+        bit8_max = info.max
+        bit8_min = info.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
         device=x.device,
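On the non-HIP path the clamp range now comes directly from torch.iinfo / torch.finfo for the requested dtype, which is what lets the same kernel serve int8 and fp8; the _is_hip branch hard-codes its limits instead. For reference, the values PyTorch reports (assuming the fp8 dtype is torch.float8_e4m3fn, available in recent PyTorch builds):

    import torch

    print(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max)
    # -128 127
    print(torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max)
    # -448.0 448.0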
@@ -250,7 +266,7 @@ def per_token_group_quant_fp8(
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
     if column_major_scales:
-        _per_token_group_quant_fp8_colmajor[(M,)](
+        _per_token_group_quant_8bit_colmajor[(M,)](
             x,
             x_q,
             x_s,
@@ -258,8 +274,8 @@ def per_token_group_quant_fp8(
             x.shape[1],
             x_s.stride(1),
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -267,15 +283,15 @@ def per_token_group_quant_fp8(
         )
     else:
         assert not scale_ue8m0
-        _per_token_group_quant_fp8[(M,)](
+        _per_token_group_quant_8bit[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            N,
            eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
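For the launches above, the grid [(M,)] runs one Triton program per token group, and the warp count is derived from the block width set earlier in the function. A worked example, assuming BLOCK is 128 (the diff's fused path asserts group_size == 128):

    BLOCK = 128
    num_warps = min(max(BLOCK // 256, 1), 8)  # 128 // 256 == 0, so max(0, 1) == 1 warp
    num_stages = 1
    # grid = (M,): one program instance per (token, group) row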
@@ -297,6 +313,117 @@ def per_token_group_quant_fp8(
     return x_q, x_s


+# backward compatibility
+per_token_group_quant_fp8 = _per_token_group_quant_8bit_raw
+
+
+def _per_token_group_quant_8bit_fuse_silu_and_mul(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
+    masked_m: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Another way to implement (can be used in e.g. comparison tests)
+    # from sgl_kernel import silu_and_mul
+    # x_after_silu_and_mul = silu_and_mul(x)
+    # return per_token_group_quant_fp8(
+    #     x_after_silu_and_mul,
+    #     group_size=group_size,
+    #     eps=eps,
+    #     column_major_scales=column_major_scales,
+    #     scale_tma_aligned=scale_tma_aligned,
+    #     scale_ue8m0=scale_ue8m0,
+    # )
+
+    from deep_gemm.utils.layout import transform_sf_into_required_layout
+    from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
+
+    assert column_major_scales
+    assert scale_tma_aligned
+    assert scale_ue8m0
+
+    needs_unsqueeze = x.dim() == 2
+    if needs_unsqueeze:
+        num_tokens, _ = x.shape
+        x = x.unsqueeze(0)
+        assert masked_m is None
+        masked_m = torch.tensor([num_tokens], device=x.device, dtype=torch.int32)
+
+    # Use `zeros` for easier testing
+    output = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2),
+        device=x.device,
+        dtype=dst_dtype,
+    )
+    # Use `zeros` for easier testing
+    output_scale_for_kernel = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2 // group_size),
+        device=x.device,
+        dtype=torch.float32,
+    )
+    silu_and_mul_masked_post_quant_fwd(
+        input=x,
+        output=output,
+        output_scale=output_scale_for_kernel,
+        quant_group_size=group_size,
+        masked_m=masked_m,
+        scale_ue8m0=scale_ue8m0,
+    )
+
+    assert group_size == 128
+    output_scale = transform_sf_into_required_layout(
+        output_scale_for_kernel,
+        num_groups=output.shape[0],
+        mn=output.shape[-2],
+        k=output.shape[-1],
+        recipe=(1, group_size, group_size),
+        is_sfa=True,
+    )
+
+    if needs_unsqueeze:
+        output = output.squeeze(0)
+        output_scale = output_scale.squeeze(0)
+
+    return output, output_scale
+
+
+def per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if fuse_silu_and_mul:
+        return _per_token_group_quant_8bit_fuse_silu_and_mul(
+            x=x,
+            group_size=group_size,
+            dst_dtype=dst_dtype,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            masked_m=masked_m,
+        )
+    else:
+        return _per_token_group_quant_8bit_raw(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            dtype=dst_dtype,
+        )
+
+
 def create_per_token_group_quant_fp8_output_scale(
     x_shape,
     device,
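A possible usage sketch for the new per_token_group_quant_8bit entry point (not from the commit; the shapes, device, and fp8 dtype are assumptions, and a CUDA build with Triton is required):

    import torch
    from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_8bit

    x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

    # fp8 destination: one scale per 128-wide group of every token.
    x_q, x_s = per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.float8_e4m3fn)

    # int8 destination now goes through the same entry point.
    x_q_i8, x_s_i8 = per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.int8)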
@@ -307,16 +434,16 @@ def create_per_token_group_quant_fp8_output_scale(
 ):
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x_shape
+        *x_batch, x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        return torch.zeros(
-            (aligned_k // 4, aligned_mn),
+        return torch.empty(
+            (*x_batch, aligned_k // 4, aligned_mn),
             device=device,
             dtype=torch.int,
-        ).transpose(0, 1)[:x_s_mn, :]
+        ).transpose(-1, -2)[..., :x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
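To make the UE8M0 scale-layout arithmetic above concrete, here is a worked example with hypothetical sizes (chosen only for illustration):

    x_q_mn, x_q_k = 4096, 7168              # quantized activation shape per batch entry
    x_s_mn, x_s_k = x_q_mn, x_q_k // 128    # 4096 x 56 scale entries
    aligned_mn, aligned_k = 4096, 56        # align(., 4) changes nothing for these sizes
    backing = (aligned_k // 4, aligned_mn)  # (14, 4096) int32 buffer, presumably packing
                                            # four one-byte UE8M0 exponents per int32
    # .transpose(-1, -2)[..., :x_s_mn, :] then exposes it as a (4096, 14) view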
@@ -341,39 +468,6 @@ def create_per_token_group_quant_fp8_output_scale(
     )


-# TODO maybe unify int8 and fp8 code later
-def per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    dst_dtype: torch.dtype,
-    eps: float = 1e-10,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-    scale_ue8m0: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
-
-    if dst_dtype == torch.int8:
-        assert not column_major_scales
-        assert not scale_tma_aligned
-        assert not scale_ue8m0
-        return per_token_group_quant_int8(
-            x=x,
-            group_size=group_size,
-            eps=eps,
-            dtype=dst_dtype,
-        )
-
-    return per_token_group_quant_fp8(
-        x=x,
-        group_size=group_size,
-        eps=eps,
-        column_major_scales=column_major_scales,
-        scale_tma_aligned=scale_tma_aligned,
-        scale_ue8m0=scale_ue8m0,
-    )
-
-
 def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
@@ -381,15 +475,19 @@ def sglang_per_token_group_quant_fp8(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))
+    x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
-        x_shape=x.shape,
+        x_shape=out_shape,
         device=x.device,
         group_size=group_size,
         column_major_scales=column_major_scales,
@@ -414,6 +512,8 @@ def sglang_per_token_group_quant_8bit(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     from sglang.srt.layers.quantization.int8_kernel import (
         sglang_per_token_group_quant_int8,
@@ -422,6 +522,8 @@ def sglang_per_token_group_quant_8bit(
     if dst_dtype == torch.int8:
         assert not column_major_scales
         assert not scale_tma_aligned
+        assert not fuse_silu_and_mul
+        assert masked_m is None
         return sglang_per_token_group_quant_int8(
             x=x,
             group_size=group_size,
@@ -436,6 +538,8 @@ def sglang_per_token_group_quant_8bit(
         column_major_scales=column_major_scales,
         scale_tma_aligned=scale_tma_aligned,
         scale_ue8m0=scale_ue8m0,
+        fuse_silu_and_mul=fuse_silu_and_mul,
+        masked_m=masked_m,
     )
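One quick sanity check for the backward-compatibility alias introduced above (a sketch; it assumes the sglang package imports in your environment):

    from sglang.srt.layers.quantization import fp8_kernel

    # The old public name now points at the renamed implementation.
    assert fp8_kernel.per_token_group_quant_fp8 is fp8_kernel._per_token_group_quant_8bit_raw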