change / sglang · Commits

Commit b58c3c28 (Unverified)
Support ue8m0 for triton quant kernel (#7603)

Authored Jul 28, 2025 by fzyzcjy; committed by GitHub Jul 27, 2025.
Parent: df906455
Showing 1 changed file with 78 additions and 48 deletions.
python/sglang/srt/layers/quantization/fp8_kernel.py (+78, -48)
...
...
@@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor(
     fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group
     quantization on a tensor.
...
...
@@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
+    if SCALE_UE8M0:
+        y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
...
...
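These two added lines are the core of the kernel change: when SCALE_UE8M0 is set, the per-group scale y_s = absmax / fp8_max is rounded up to the next power of two, so it can later be stored as an 8-bit exponent-only value (UE8M0). A minimal PyTorch sketch of the same rounding outside the Triton kernel (illustrative only, not code from this commit):

import torch

def round_scale_up_to_pow2(y_s: torch.Tensor) -> torch.Tensor:
    # Same math as tl.exp2(tl.ceil(tl.log2(tl.abs(y_s)))) in the kernel:
    # round each positive scale up to the nearest power of two, so it is
    # exactly representable by an exponent-only format.
    return torch.exp2(torch.ceil(torch.log2(y_s.abs())))

print(round_scale_up_to_pow2(torch.tensor([0.3, 1.0, 1.5])))
# tensor([0.5000, 1.0000, 2.0000])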
@@ -209,6 +212,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
...
...
@@ -229,29 +233,17 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=False,
+    )
+
     M = x.numel() // group_size
     N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
...
...
@@ -271,8 +263,10 @@ def per_token_group_quant_fp8(
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
+            SCALE_UE8M0=scale_ue8m0,
         )
     else:
+        assert not scale_ue8m0
         _per_token_group_quant_fp8[(M,)](
             x,
             x_q,
...
...
@@ -287,57 +281,93 @@ def per_token_group_quant_fp8(
             num_stages=num_stages,
         )
 
+    if scale_ue8m0:
+        from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+        assert group_size == 128
+        x_s = transform_sf_into_required_layout(
+            x_s,
+            num_groups=None,
+            mn=x_q.shape[0],
+            k=x_q.shape[1],
+            recipe=(1, group_size, group_size),
+            is_sfa=True,
+        )
+
     return x_q, x_s
 
 
-def sglang_per_token_group_quant_fp8(
-    x: torch.Tensor,
-    group_size: int,
-    eps: float = 1e-10,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-    scale_ue8m0: bool = False,
+def create_per_token_group_quant_fp8_output_scale(
+    x_shape,
+    device,
+    group_size,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
 ):
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x.shape
+        x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        x_s = torch.zeros(
+        return torch.zeros(
             (aligned_k // 4, aligned_mn),
-            device=x.device,
+            device=device,
             dtype=torch.int,
         ).transpose(0, 1)[:x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
             # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
+            aligned_size = (x_shape[-2] + 3) // 4 * 4
+            return torch.empty(
+                x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
+                device=device,
                 dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
+            ).permute(-1, -2)[: x_shape[-2], :]
         else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
+            return torch.empty(
+                (x_shape[-1] // group_size,) + x_shape[:-1],
+                device=device,
                 dtype=torch.float32,
             ).permute(-1, -2)
     else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
+        return torch.empty(
+            x_shape[:-1] + (x_shape[-1] // group_size,),
+            device=device,
             dtype=torch.float32,
         )
 
 
+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    if scale_ue8m0:
+        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
+        assert x.shape[-1] % (group_size * 4) == 0
+
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
+
     if x.shape[0] > 0:
         sgl_per_token_group_quant_fp8(
             x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
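In the new create_per_token_group_quant_fp8_output_scale helper, the scale_ue8m0 branch allocates the scales as a torch.int buffer of shape (aligned_k // 4, aligned_mn), which suggests four 1-byte UE8M0 exponents packed per 32-bit word, with both dimensions padded to multiples of 4 via the file's align helper (defined elsewhere in fp8_kernel.py, not shown in this diff). A standalone sketch of the same shape arithmetic, assuming align(v, m) rounds v up to a multiple of m:

import torch

def align(v: int, m: int) -> int:
    # Assumed behavior of the align() helper used in the diff above.
    return (v + m - 1) // m * m

def ue8m0_scale_buffer(x_shape, device="cpu"):
    # Mirrors the scale_ue8m0 branch: one scale per 128-wide group along k,
    # both dimensions padded to multiples of 4, stored as int32 words.
    x_q_mn, x_q_k = x_shape
    x_s_mn, x_s_k = x_q_mn, x_q_k // 128
    aligned_mn = align(x_s_mn, 4)
    aligned_k = align(x_s_k, 4)
    return torch.zeros(
        (aligned_k // 4, aligned_mn), device=device, dtype=torch.int
    ).transpose(0, 1)[:x_s_mn, :]

# Example: a (6, 1024) input has 8 groups of 128 per row; the padded backing
# buffer is (2, 8) int32 and the returned view has shape (6, 2).
print(ue8m0_scale_buffer((6, 1024)).shape)  # torch.Size([6, 2])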
...
...
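With this change the Triton path accepts the same scale_ue8m0 flag that the CUDA path (sglang_per_token_group_quant_fp8) already takes. A hypothetical usage sketch, assuming a CUDA device with Triton and deep_gemm available; tensor sizes are illustrative, not taken from the commit:

import torch
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)

# Per the diff, the UE8M0 path requires column-major scales and group_size == 128,
# and routes the scales through deep_gemm's transform_sf_into_required_layout.
x_q, x_s = per_token_group_quant_fp8(
    x,
    group_size=128,
    column_major_scales=True,
    scale_ue8m0=True,
)
print(x_q.dtype, x_s.shape)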