Unverified commit 4efe844a authored by Morpheus Guo, committed by GitHub

enable aiter gemm_a8w8_bpreshuffle for ptpc gemm (#8555)

parent bde73ee4
......@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import get_bool_env_var, is_hip
__all__ = ["CompressedTensorsW8A8Fp8"]
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _use_aiter:
from aiter.ops.shuffle import shuffle_weight
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
......@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
else:
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
if _use_aiter:
layer.weight = Parameter(
shuffle_weight(weight, (16, 16)), requires_grad=False
)
else:
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
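
For orientation, the following is a minimal sketch (not part of the diff) of the two weight layouts this hunk chooses between. It assumes the loaded weight has shape (n, k); the exact 16x16 tile ordering applied by aiter's shuffle_weight is not reproduced here, a placeholder reshape merely stands in for some tile reordering.

import torch

n, k = 64, 32
weight = torch.randn(n, k)   # stand-in for the fp8 weight, shape (n, k)
use_aiter = False            # True only on ROCm with SGLANG_USE_AITER=1 and aiter installed

if use_aiter:
    # aiter path: keep the (n, k) shape but reorder elements into 16x16 tiles
    # so the preshuffled GEMM kernel can access them in its preferred order.
    # The real code calls shuffle_weight(weight, (16, 16)) from aiter.ops.shuffle;
    # the reshape below is only a placeholder for that reordering (assumption).
    tiles = weight.reshape(n // 16, 16, k // 16, 16).permute(0, 2, 1, 3)
    stored = tiles.contiguous().view(n, k)
else:
    # default path: store the transpose (k, n), which is the layout
    # torch._scaled_mm expects for its second operand.
    stored = weight.t()
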
......
......@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _use_aiter:
import aiter
from aiter import gemm_a8w8_blockscale, get_hip_quant
from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
......@@ -642,25 +642,49 @@ def apply_fp8_linear(
use_per_token_if_dynamic
and not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM
and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
):
# For now validated on the ROCm platform only.
# fp8 rowwise scaling in torch._scaled_mm was introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, and only exists in torch 2.7 and above.
# On the CUDA platform, please verify that
# torch._scaled_mm supports rowwise scaled GEMM.
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias,
)
return _process_scaled_mm_output(output, input_2d.shape, output_shape)
# Reaching this branch means dynamic per-token-per-channel quantization is used:
# per-token quant for the input matrix: each row (one token) has one scale factor
# per-channel quant for the weight matrix: each column (one channel) has one scale factor
if _use_aiter:
# gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
# XQ -> input tensor, shape = (m, k)
# WQ -> weight tensor, shape = (n, k); the preshuffled layout gives better performance
# x_scale -> input scale tensor, shape = (m, 1)
# w_scale -> weight scale tensor, shape = (n, 1)
# dtype -> output dtype
output = gemm_a8w8_bpreshuffle(
XQ=qinput,
WQ=weight,
x_scale=x_scale,
w_scale=weight_scale,
dtype=input.dtype,
)
if bias is not None:
output += bias
return _process_scaled_mm_output(
output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
)
else:
# For now validated on the ROCm platform only.
# fp8 rowwise scaling in torch._scaled_mm was introduced in
# https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
# and ROCm 6.3, and only exists in torch 2.7 and above.
# On the CUDA platform, please verify that
# torch._scaled_mm supports rowwise scaled GEMM.
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias,
)
return _process_scaled_mm_output(
output, input_2d.shape, output_shape
)
else:
# Fallback for channelwise case, where we use unfused DQ
# due to limitations with scaled_mm
......
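
To complement the comments in the new branch above, here is a small, self-contained sketch of the per-token / per-channel scale bookkeeping. It emulates the quantization in plain float and calls neither aiter nor torch._scaled_mm; the 448 clamp value is just the fp8 e4m3 maximum, used here for illustration.

import torch

m, k, n = 4, 8, 6
x = torch.randn(m, k)   # activations: one scale per row (token)
w = torch.randn(n, k)   # weight: one scale per row, i.e. per output channel

x_scale = x.abs().amax(dim=1, keepdim=True) / 448.0   # (m, 1)
w_scale = w.abs().amax(dim=1, keepdim=True) / 448.0   # (n, 1)

xq = x / x_scale        # "quantized" activations (kept in float for this sketch)
wq = w / w_scale        # "quantized" weight

# GEMM on the quantized operands, then rescale rows by x_scale and columns by w_scale.
out = (xq @ wq.t()) * x_scale * w_scale.t()            # (m, n)

assert torch.allclose(out, x @ w.t(), atol=1e-4)

This is the identity that the fused paths compute in a single kernel, with x_scale of shape (m, 1) and w_scale of shape (n, 1) exactly as documented in the comments above.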