Unverified Commit 4c6675c4 authored by valarLip's avatar valarLip Committed by GitHub
Browse files

enable aiter fp8 blockscale quant (#7520)

parent e21aa1df
...@@ -42,7 +42,10 @@ _is_fp8_fnuz = is_fp8_fnuz() ...@@ -42,7 +42,10 @@ _is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _use_aiter: if _use_aiter:
from aiter import gemm_a8w8_blockscale_CK import aiter
from aiter import gemm_a8w8_blockscale_CK, get_hip_quant
aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
if _is_cuda: if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
...@@ -271,9 +274,7 @@ def aiter_w8a8_block_fp8_linear( ...@@ -271,9 +274,7 @@ def aiter_w8a8_block_fp8_linear(
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
input_2d, block_size[1], column_major_scales=False
)
output = gemm_a8w8_blockscale_CK( output = gemm_a8w8_blockscale_CK(
q_input, weight, x_scale, weight_scale, dtype=input.dtype q_input, weight, x_scale, weight_scale, dtype=input.dtype
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment