Unverified commit 6a47b730, authored by Baizhou Zhang and committed by GitHub
Browse files

Remove contiguous before Flashinfer groupwise fp8 gemm (#6804)

parent c429919d
...@@ -166,11 +166,13 @@ def flashinfer_gemm_w8a8_block_fp8_linear( ...@@ -166,11 +166,13 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
input_2d, block_size[1], column_major_scales=False input_2d, block_size[1], column_major_scales=False
) )
x_scale_input = x_scale.T.contiguous()
weight_scale_input = weight_scale.T.contiguous()
output = gemm_fp8_nt_groupwise( output = gemm_fp8_nt_groupwise(
q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype q_input,
weight,
x_scale,
weight_scale,
scale_major_mode="K",
out_dtype=input_2d.dtype,
) )
if bias is not None: if bias is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment