Unverified Commit 47539cfd authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix mismatched nvfp4 gemm output shape (#29742)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 2afcec4d
......@@ -184,7 +184,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight_packed.shape[0]]
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment