Unverified Commit 92823069 authored by kk, committed by GitHub

Fix failing CI test "test_eval_fp8_accuracy" (#5185)


Co-authored-by: wunhuang <wunhuang@amd.com>
parent d2e507df
@@ -243,8 +243,18 @@ def apply_fp8_linear(
         if _is_cuda:
             qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
         else:
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic
-            )
+            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if the weight uses per-tensor scaling.
+            # The final solution should: 1. add support for per-tensor activation scaling;
+            # 2. solve the torch.compile error caused by weight_scale.numel() == 1 with x_scale.numel() > 1 (below, line #308)
+            if _is_hip and weight_scale.numel() == 1:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            else:
+                qinput, x_scale = per_token_group_quant_fp8(
+                    input_2d, group_size=input_2d.shape[1]
+                )
 
     if cutlass_fp8_supported:
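
To make the dispatch concrete, here is a minimal standalone sketch of the two activation-quantization granularities the patch chooses between: one scale for the whole tensor (matching a per-tensor-scaled weight on ROCm) versus one scale per token (a group spanning the full hidden dimension, i.e. group_size == input_2d.shape[1]). It uses plain PyTorch in place of the real kernels; the helper names (fake_per_tensor_quant, fake_per_token_quant, quantize_activation) are hypothetical and are not sglang APIs. It assumes a PyTorch build with float8 support (>= 2.1).

    # Sketch only -- the real code calls ops.scaled_fp8_quant and
    # per_token_group_quant_fp8; these helpers just emulate the scaling math.
    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

    def fake_per_tensor_quant(x: torch.Tensor):
        # One scale for the entire activation tensor, mirroring the case
        # where the weight is per-tensor scaled (weight_scale.numel() == 1).
        scale = x.abs().max().clamp(min=1e-12) / FP8_MAX
        q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return q, scale.reshape(1)

    def fake_per_token_quant(x: torch.Tensor):
        # One scale per row/token; with group_size == x.shape[1] each group
        # covers the whole hidden dimension, i.e. per-token scaling.
        scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return q, scale

    def quantize_activation(x, weight_scale, is_hip):
        # Mirrors the patched branch: on ROCm, a per-tensor-scaled weight
        # forces per-tensor activation scaling, dodging the torch.compile
        # error from mixing weight_scale.numel() == 1 with x_scale.numel() > 1.
        if is_hip and weight_scale.numel() == 1:
            return fake_per_tensor_quant(x)
        return fake_per_token_quant(x)

    x = torch.randn(4, 16)
    q, s = quantize_activation(x, weight_scale=torch.tensor([0.02]), is_hip=True)
    print(q.shape, s.numel())  # torch.Size([4, 16]) 1

The design choice in the patch is the conservative one: rather than teaching the per-token path to coexist with a per-tensor weight scale under torch.compile, it keeps the scaling granularities matched on HIP until per-tensor activation scaling is properly supported.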