Unverified commit 4efe844a, authored by Morpheus Guo and committed by GitHub

enable aiter gemm_a8w8_bpreshuffle for per-token-per-channel (ptpc) gemm (#8555)

parent bde73ee4
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
......
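
For context, here is a minimal sketch (outside the diff) of what the new load-time branch does: gate on the SGLANG_USE_AITER flag and, when it is enabled on ROCm, preshuffle the fp8 weight into the 16x16 tile layout that aiter's bpreshuffle GEMM expects, instead of storing the plain transpose. The env_flag helper, tensor shapes, and fp8 dtype below are illustrative assumptions, and running the aiter branch assumes a ROCm GPU with aiter installed.

import os

import torch

# Illustrative stand-in for sglang.srt.utils.get_bool_env_var (name assumed).
def env_flag(name: str) -> bool:
    return os.environ.get(name, "false").lower() in ("1", "true", "yes")

use_aiter = env_flag("SGLANG_USE_AITER")  # sglang additionally requires is_hip()

# Dummy fp8 weight in the (n, k) layout that weight loading produces.
weight = torch.randn(128, 256, device="cuda").to(torch.float8_e4m3fn)

if use_aiter:
    # aiter's bpreshuffle GEMM expects the weight rearranged into 16x16 tiles;
    # doing this once at load time keeps the per-call GEMM path free of layout work.
    from aiter.ops.shuffle import shuffle_weight

    packed_weight = shuffle_weight(weight, (16, 16))
else:
    # Default path: store the transposed (k, n) weight used by torch._scaled_mm.
    packed_weight = weight.t()
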
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
@@ -642,25 +642,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-            # and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale.t(),
-                bias=bias,
-            )
-            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+            # Reaching this branch means dynamic per-token-per-channel quantization:
+            # the input matrix uses a per-token scale (one scale per row/token) and
+            # the weight matrix uses a per-channel scale (one scale per column/channel).
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ      -> input tensor, shape (m, k)
+                # WQ      -> weight tensor, shape (n, k), preshuffled for better perf
+                # x_scale -> input scale tensor, shape (m, 1)
+                # w_scale -> weight scale tensor, shape (n, 1)
+                # dtype   -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on the ROCm platform:
+                # fp8 rowwise scaling in torch._scaled_mm was introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, and is only available in torch 2.7 and above.
+                # On CUDA, please verify that torch._scaled_mm supports
+                # rowwise scaled GEMM.
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
......
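
As a usage illustration of the new path, here is a minimal sketch of calling gemm_a8w8_bpreshuffle with the shapes documented in the diff comments above: fp8 activations of shape (m, k) with a per-token scale of shape (m, 1), and a preshuffled fp8 weight of shape (n, k) with a per-channel scale of shape (n, 1). The concrete sizes, device, scale dtype, and fp8 dtype are assumptions chosen for illustration; running it requires a ROCm GPU with aiter installed.

import torch
from aiter import gemm_a8w8_bpreshuffle
from aiter.ops.shuffle import shuffle_weight

m, n, k = 32, 4096, 4096
out_dtype = torch.bfloat16

# Per-token quantized activations: fp8 (m, k), one scale per row (token).
xq = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fnuz)
x_scale = torch.rand(m, 1, device="cuda", dtype=torch.float32)

# Per-channel quantized weight: fp8 (n, k), one scale per output channel (column),
# preshuffled into the 16x16 tile layout that the kernel expects.
wq = shuffle_weight(
    torch.randn(n, k, device="cuda").to(torch.float8_e4m3fnuz), (16, 16)
)
w_scale = torch.rand(n, 1, device="cuda", dtype=torch.float32)

# Fused dequantizing GEMM: roughly (xq * x_scale) @ (wq * w_scale).T in out_dtype.
out = gemm_a8w8_bpreshuffle(xq, wq, x_scale, w_scale, out_dtype)  # shape (m, n)

Note that, as in the diff, bias is added to the output after the call, since the fused kernel invocation shown above takes no bias argument.
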