Unverified Commit 065ce815 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny cleanup fp4 gemm calls (#11537)

parent 8e51049f
...@@ -852,25 +852,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase): ...@@ -852,25 +852,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
if enable_flashinfer_fp4_gemm: if enable_flashinfer_fp4_gemm:
w = layer.weight.T w = layer.weight.T
w_scale_interleaved = layer.weight_scale_interleaved.T w_scale_interleaved = layer.weight_scale_interleaved.T
if USE_CUTLASS_BACKEND_FOR_FP4_GEMM: out = fp4_gemm(
out = fp4_gemm( x_fp4,
x_fp4, w,
w, x_scale_interleaved,
x_scale_interleaved, w_scale_interleaved,
w_scale_interleaved, layer.alpha,
layer.alpha, output_dtype,
output_dtype, **(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
backend="cutlass", )
)
else:
out = fp4_gemm(
x_fp4,
w,
x_scale_interleaved,
w_scale_interleaved,
layer.alpha,
output_dtype,
)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment