Unverified commit 4069db3f, authored by roikoren755, committed by GitHub
Browse files

[Bugfix] Enable padded FP4 quantization (#25947)


Signed-off-by: Roi Koren <roik@nvidia.com>
parent 0d37450e
...@@ -1384,7 +1384,7 @@ def scaled_fp4_quant( ...@@ -1384,7 +1384,7 @@ def scaled_fp4_quant(
rounded_m = round_up(m, 128) rounded_m = round_up(m, 128)
scale_n = n // block_size scale_n = n // block_size
rounded_n = round_up(scale_n, 4) rounded_n = round_up(scale_n, 4)
output_scale = torch.empty( output_scale = torch.zeros(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32 (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
) )
......
...@@ -386,8 +386,6 @@ def flashinfer_scaled_fp4_mm( ...@@ -386,8 +386,6 @@ def flashinfer_scaled_fp4_mm(
assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.stride(-1) == 1 and b.stride(-1) == 1
assert a.shape[1] == b.shape[1] assert a.shape[1] == b.shape[1]
assert block_scale_a.shape[1] == a.shape[1] // 8
assert block_scale_b.shape[1] == b.shape[1] // 8
if backend == "cutlass": if backend == "cutlass":
block_scale_a = block_scale_a.view(torch.uint8) block_scale_a = block_scale_a.view(torch.uint8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment