Unverified Commit 8723b4f1 authored by Elfie Guo, committed by GitHub
Browse files

Use FlashInfer's TRTLLM FP8 Blockscale GEMM (#8588)

parent 62f99e08
...@@ -161,16 +161,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear( ...@@ -161,16 +161,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = sglang_per_token_group_quant_fp8( q_input, x_scale = sglang_per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False input_2d, block_size[1], column_major_scales=True
) )
# TRTLLM requires column-major scaling factors
output = gemm_fp8_nt_groupwise( output = gemm_fp8_nt_groupwise(
q_input, q_input,
weight, weight,
x_scale, x_scale,
weight_scale, weight_scale,
scale_major_mode="K",
out_dtype=input_2d.dtype, out_dtype=input_2d.dtype,
backend="trtllm",
) )
if bias is not None: if bias is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment