sglang commit 4efe844a (unverified)
Authored Sep 06, 2025 by Morpheus Guo; committed via GitHub on Sep 05, 2025

enable aiter gemm_a8w8_bpreshuffle for ptpc gemm (#8555)
Parent: bde73ee4

Showing 2 changed files with 56 additions and 20 deletions (+56 -20):

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py  (+13 -1)
python/sglang/srt/layers/quantization/fp8_utils.py  (+43 -19)
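As the diffs below show, the new kernel path is opt-in: it is taken only when the SGLANG_USE_AITER environment variable is truthy and the PyTorch build is ROCm (HIP). A rough stand-in for that gate, for illustration only (the real code uses sglang's get_bool_env_var and is_hip helpers, as seen in the diff):

import os

import torch


# Illustrative re-implementation of the commit's gate:
# get_bool_env_var("SGLANG_USE_AITER") and is_hip()
def use_aiter_path() -> bool:
    flag = os.environ.get("SGLANG_USE_AITER", "false").lower() in ("1", "true", "yes")
    on_rocm = torch.version.hip is not None  # non-None only on HIP/ROCm builds
    return flag and on_rocm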
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py (view file @ 4efe844a)

@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
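The net effect of the second hunk above: with the aiter path enabled, the FP8 weight keeps its (N, K) layout but is reordered once into 16x16 tiles by aiter's shuffle_weight, instead of being stored transposed for torch._scaled_mm. A minimal sketch of the two layouts under that assumption (the helper name pack_fp8_weight is hypothetical; shuffle_weight, the (16, 16) tile size, and Parameter come from the diff):

import torch
from torch.nn import Parameter


def pack_fp8_weight(weight: torch.Tensor, use_aiter: bool) -> Parameter:
    # Sketch of the weight post-processing branch added in this commit;
    # `weight` is the loaded FP8 checkpoint weight of shape (N, K).
    if use_aiter:
        # aiter's gemm_a8w8_bpreshuffle consumes the (N, K) weight reordered
        # into 16x16 tiles; the reordering is done once at load time.
        from aiter.ops.shuffle import shuffle_weight

        return Parameter(shuffle_weight(weight, (16, 16)), requires_grad=False)
    # The torch._scaled_mm path stores the transposed (K, N) view instead.
    return Parameter(weight.t(), requires_grad=False)

Only the weight storage layout differs between the two GEMM backends; the weight scales are handled the same way in both branches.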
python/sglang/srt/layers/quantization/fp8_utils.py (view file @ 4efe844a)

@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
@@ -642,25 +642,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-            # and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if
-            # torch._scaled_mm supports rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+            # Entering this branch means dynamic per-token-per-channel quant is used:
+            # per-token scale quant for the input matrix, every row (one token) has one scale factor;
+            # per-channel scale quant for the weight matrix, every column (one channel) has one scale factor.
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k), preshuffled for better performance
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n, 1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if
+                # torch._scaled_mm supports rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
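The comments in the new branch spell out the per-token-per-channel (ptpc) shapes: qinput is an (M, K) FP8 tensor with one scale per row (x_scale of shape (M, 1)), and weight is an (N, K) FP8 tensor with one scale per output channel (w_scale of shape (N, 1)). As a sanity reference, here is a dequantize-then-matmul sketch of what either gemm_a8w8_bpreshuffle or the rowwise torch._scaled_mm call computes; the function name is hypothetical, and the fused kernels apply the scales in the GEMM epilogue rather than materializing float32 copies:

import torch


def ptpc_fp8_gemm_reference(
    qinput: torch.Tensor,   # (M, K) FP8, per-token quantized
    weight: torch.Tensor,   # (N, K) FP8, per-channel quantized
    x_scale: torch.Tensor,  # (M, 1) per-row scales
    w_scale: torch.Tensor,  # (N, 1) per-output-channel scales
    bias: torch.Tensor | None = None,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    # out[m, n] = x_scale[m] * w_scale[n] * sum_k qinput[m, k] * weight[n, k]
    a = qinput.to(torch.float32) * x_scale
    b = weight.to(torch.float32) * w_scale
    out = a @ b.t()
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)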