Unverified Commit 065ce815 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny cleanup fp4 gemm calls (#11537)

parent 8e51049f
......@@ -852,25 +852,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
if enable_flashinfer_fp4_gemm:
w = layer.weight.T
w_scale_interleaved = layer.weight_scale_interleaved.T
if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
out = fp4_gemm(
x_fp4,
w,
x_scale_interleaved,
w_scale_interleaved,
layer.alpha,
output_dtype,
backend="cutlass",
)
else:
out = fp4_gemm(
x_fp4,
w,
x_scale_interleaved,
w_scale_interleaved,
layer.alpha,
output_dtype,
)
out = fp4_gemm(
x_fp4,
w,
x_scale_interleaved,
w_scale_interleaved,
layer.alpha,
output_dtype,
**(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
)
if bias is not None:
out = out + bias
return out.view(*output_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment