sglang / Commits / 4efe844a
"examples/vscode:/vscode.git/clone" did not exist on "b3841c254430b7d9f8f2ae4f33ad6e23236b0f86"
Unverified commit 4efe844a, authored Sep 06, 2025 by Morpheus Guo, committed by GitHub on Sep 05, 2025
enable aiter gemm_a8w8_bpreshuffle for ptpc gemm (#8555)
Parent: bde73ee4
Showing 2 changed files with 56 additions and 20 deletions (+56 -20)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py  +13 -1
python/sglang/srt/layers/quantization/fp8_utils.py  +43 -19
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip

 __all__ = ["CompressedTensorsW8A8Fp8"]

+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+

 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
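A standalone sketch of what this gate checks, assuming nothing beyond stock PyTorch (sglang's own get_bool_env_var and is_hip helpers are the real implementation; this is only a rough equivalent):

import os
import torch

def _env_flag(name: str) -> bool:
    # Rough equivalent of get_bool_env_var: treat "1"/"true"/"yes" as enabled.
    return os.environ.get(name, "false").lower() in ("1", "true", "yes")

# The aiter weight preshuffle is only used on ROCm (HIP) builds with the flag set.
_use_aiter = _env_flag("SGLANG_USE_AITER") and (torch.version.hip is not None)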
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data

-        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
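Putting the two hunks together, weight finalization now branches on the aiter flag. A hedged sketch of that branch, where shuffle_weight_stub is only a placeholder: the real aiter.ops.shuffle.shuffle_weight rearranges the (n, k) weight into the blocked layout gemm_a8w8_bpreshuffle expects, and that layout is not reproduced here.

import torch
from torch.nn import Parameter

def shuffle_weight_stub(w: torch.Tensor, block=(16, 16)) -> torch.Tensor:
    # Placeholder for aiter.ops.shuffle.shuffle_weight: keeps the (n, k) shape,
    # but does NOT perform the real (16, 16) block reordering.
    return w.contiguous()

def finalize_weight(weight: torch.Tensor, use_aiter: bool) -> Parameter:
    if use_aiter:
        # aiter path: weight stays (n, k), preshuffled for the fused kernel
        return Parameter(shuffle_weight_stub(weight, (16, 16)), requires_grad=False)
    # default path: transposed (k, n) weight for torch._scaled_mm
    return Parameter(weight.t(), requires_grad=False)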
python/sglang/srt/layers/quantization/fp8_utils.py
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant

     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
@@ -642,8 +642,31 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+            # Reaching this branch means dynamic per-token-per-channel quant is used:
+            # per-token scales for the input matrix, one scale factor per row (token);
+            # per-channel scales for the weight matrix, one scale factor per column (channel).
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k); the preshuffled layout gives better perf
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n, 1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
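In dequantized terms, the fused call above should be equivalent (up to rounding) to an fp32 matmul followed by row and column rescaling. A reference sketch of that semantics, using the shapes documented in the diff comments; this is a numerical reference for illustration, not the aiter kernel:

import torch

def reference_a8w8_ptpc(xq, wq, x_scale, w_scale, out_dtype=torch.bfloat16):
    # xq: (m, k) fp8 activations, wq: (n, k) fp8 weight
    # x_scale: (m, 1) per-token scales, w_scale: (n, 1) per-channel scales
    acc = xq.float() @ wq.float().t()                  # (m, n), accumulate in fp32
    return (acc * x_scale * w_scale.t()).to(out_dtype)  # rescale rows and columns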
@@ -659,8 +682,9 @@ def apply_fp8_linear(
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
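For completeness, a hedged sketch of the non-aiter branch this hunk reindents: rowwise fp8 torch._scaled_mm with a per-token scale_a of shape (m, 1) and a per-channel scale_b of shape (1, n). It requires a PyTorch build that includes pytorch/pytorch#144432; the keyword arguments follow the diff, while the helper name and wrapper itself are illustrative.

import torch

def rowwise_scaled_mm_fallback(qinput, weight_t, x_scale, weight_scale, bias, out_dtype):
    # qinput: (m, k) fp8, weight_t: (k, n) fp8 (the transposed weight stored by the
    # non-aiter load path), x_scale: (m, 1), weight_scale: (n, 1)
    return torch._scaled_mm(
        qinput,
        weight_t,
        scale_a=x_scale,
        scale_b=weight_scale.t(),  # (1, n)
        bias=bias,
        out_dtype=out_dtype,
    )

To exercise the new aiter path instead of this fallback, SGLANG_USE_AITER=1 must be set on a ROCm (HIP) build; otherwise apply_fp8_linear keeps using the scaled_mm route above.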