sglang
Commits
4efe844a
"vscode:/vscode.git/clone" did not exist on "2bc4df229e44d37ad42c6078fe018ebc2a920af2"
Unverified commit 4efe844a, authored Sep 06, 2025 by Morpheus Guo, committed by GitHub Sep 05, 2025
enable aiter gemm_a8w8_bpreshuffle for ptpc gemm (#8555)
parent bde73ee4
Showing 2 changed files with 56 additions and 20 deletions (+56 -20)

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py  +13 -1
python/sglang/srt/layers/quantization/fp8_utils.py  +43 -19
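The "ptpc gemm" in the commit title is per-token (activation), per-channel (weight) FP8 quantization: every input row carries its own scale and every weight output channel carries its own scale, matching the (m, 1) and (n, 1) scale shapes documented in the fp8_utils.py hunk below. A minimal sketch of that scaling scheme in plain PyTorch follows; it is illustrative only (toy shapes, simple max-based scales) and is not code from this commit.

# Illustrative sketch of per-token-per-channel (ptpc) FP8 scaling; plain PyTorch,
# not the aiter kernel this commit enables. 448.0 is the e4m3 max value.
import torch

m, k, n = 4, 8, 6
x = torch.randn(m, k)
w = torch.randn(n, k)

x_scale = x.abs().amax(dim=1, keepdim=True) / 448.0  # one scale per token (row), shape (m, 1)
w_scale = w.abs().amax(dim=1, keepdim=True) / 448.0  # one scale per output channel, shape (n, 1)

XQ = (x / x_scale).to(torch.float8_e4m3fn)  # quantized activations, shape (m, k)
WQ = (w / w_scale).to(torch.float8_e4m3fn)  # quantized weight, shape (n, k)

# Dequantizing GEMM: scales broadcast over rows (tokens) and columns (channels).
out = (XQ.float() @ WQ.float().t()) * x_scale * w_scale.t()
print(out.shape)  # torch.Size([4, 6])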
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
 
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
...
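The module-level flags added above mean the weight-preshuffle branch is taken only when the SGLANG_USE_AITER environment variable is truthy and the PyTorch build is HIP/ROCm. A small sketch of checking that gate follows; it assumes sglang is importable, and the aiter import will only succeed where aiter is installed.

# Sketch of the gating added in this file; helper names come from the diff above.
import os

os.environ["SGLANG_USE_AITER"] = "1"  # read as a boolean env var by sglang

from sglang.srt.utils import get_bool_env_var, is_hip

_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
print(_use_aiter)  # True only on a HIP/ROCm PyTorch build

if _use_aiter:
    # Reorders the FP8 weight into the (16, 16)-tiled layout used by the
    # preshuffled aiter GEMM; only importable when aiter is installed.
    from aiter.ops.shuffle import shuffle_weight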
python/sglang/srt/layers/quantization/fp8_utils.py
...
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
...
@@ -642,25 +642,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-            # and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+            # Reaching this branch means dynamic per-token-per-channel quant:
+            # per-token scale quant for the input matrix, every row (one token) has one scale factor
+            # per-channel scale quant for the weight matrix, every col (one channel) has one scale factor
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k), preshuffled for better perf
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n, 1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
...
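For orientation, here is the call added in the aiter branch above, spelled out with explicit toy shapes. The keyword signature and the (m, 1)/(n, 1) scale shapes are taken from the diff's own comments; the FP8 dtype, the preshuffle step, and the problem size are assumptions, and the snippet needs a ROCm machine with aiter installed. Bias is added after the kernel call, exactly as apply_fp8_linear now does, because gemm_a8w8_bpreshuffle is not passed a bias.

# Hedged sketch of the new per-token-per-channel aiter path (ROCm + aiter required;
# dtype and shapes are illustrative assumptions, not values from the commit).
import torch
from aiter import gemm_a8w8_bpreshuffle
from aiter.ops.shuffle import shuffle_weight

m, k, n = 4, 4096, 11008
qinput = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fnuz)  # per-token quantized input, (m, k)
weight = shuffle_weight(  # weight kept as (n, k) and preshuffled, as in compressed_tensors_w8a8_fp8.py
    torch.randn(n, k, device="cuda").to(torch.float8_e4m3fnuz), (16, 16)
)
x_scale = torch.rand(m, 1, device="cuda", dtype=torch.float32)  # one scale per token
w_scale = torch.rand(n, 1, device="cuda", dtype=torch.float32)  # one scale per channel
bias = torch.zeros(n, device="cuda", dtype=torch.bfloat16)

output = gemm_a8w8_bpreshuffle(
    XQ=qinput, WQ=weight, x_scale=x_scale, w_scale=w_scale, dtype=torch.bfloat16
)
if bias is not None:
    output += bias  # the kernel takes no bias, so it is applied afterwards
print(output.shape)  # torch.Size([4, 11008])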