Commit 8723b4f1 authored by Elfie Guo, committed by GitHub

Use FlashInfer's TRTLLM FP8 Blockscale GEMM (#8588)

parent 62f99e08
......
@@ -161,16 +161,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = sglang_per_token_group_quant_fp8(
- input_2d, block_size[1], column_major_scales=False
+ input_2d, block_size[1], column_major_scales=True
)
# TRTLLM requires column-major scaling factors
output = gemm_fp8_nt_groupwise(
q_input,
weight,
x_scale,
weight_scale,
scale_major_mode="K",
out_dtype=input_2d.dtype,
backend="trtllm",
)
if bias is not None:
......
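
Note on the scale layout: the change above switches the activation scales from row-major to column-major storage. The following is a minimal pure-Python sketch of that idea, not the sglang or FlashInfer implementation. It assumes a common per-token-group scheme in which each group's scale is its absolute maximum divided by 448 (the largest finite `float8_e4m3fn` value). The helper names `per_token_group_scales`, `flatten_row_major`, and `flatten_column_major` are hypothetical, and any padding or alignment the real kernels apply is not modeled.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def per_token_group_scales(x, group_size):
    """Return scales[token][group] = max|x| over the group / FP8_E4M3_MAX.

    Assumed scheme for illustration; the real kernel may differ.
    """
    scales = []
    for row in x:
        assert len(row) % group_size == 0, "hidden size must divide into groups"
        row_scales = []
        for start in range(0, len(row), group_size):
            group = row[start:start + group_size]
            amax = max(abs(v) for v in group)
            # Clamp to avoid a zero scale for an all-zero group.
            row_scales.append(max(amax, 1e-10) / FP8_E4M3_MAX)
        scales.append(row_scales)
    return scales


def flatten_row_major(scales):
    """Token index varies slowest: all groups of token 0, then token 1, ..."""
    return [s for row in scales for s in row]


def flatten_column_major(scales):
    """Group index varies slowest: group 0 of every token, then group 1, ..."""
    num_tokens, num_groups = len(scales), len(scales[0])
    return [scales[t][g] for g in range(num_groups) for t in range(num_tokens)]


# Two tokens, hidden size 4, group size 2 -> a 2 x 2 scale matrix.
x = [
    [448.0, -1.0, 2.0, -224.0],
    [112.0, 0.5, -56.0, 3.0],
]
scales = per_token_group_scales(x, group_size=2)
row_major = flatten_row_major(scales)
col_major = flatten_column_major(scales)
```

Both flattened lists hold the same four scales; only the memory order differs. A GEMM backend that expects one order will read the wrong scale for each (token, group) pair if handed the other, which is why the quantization call and the backend choice change together in this commit.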