Commit 4d0a1541 (unverified), authored by Michael Goin, committed by GitHub
Browse files

[Bugfix] Remove NVFP4 scales assertions to fix load_format=dummy (#18861)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 77b6e74f
......@@ -585,9 +585,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
-assert torch.allclose(
-layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
-"w1_weight_scale_2 must match w3_weight_scale_2")
+if not torch.allclose(layer.w13_weight_scale_2[:, 0],
+layer.w13_weight_scale_2[:, 1]):
+logger.warning_once(
+"w1_weight_scale_2 must match w3_weight_scale_2. "
+"Accuracy may be affected.")
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
......
......@@ -22,7 +22,12 @@ def is_fp4_marlin_supported():
def fp4_marlin_process_scales(marlin_scales):
-assert (marlin_scales >= 0).all()
+if not (marlin_scales >= 0).all():
+logger.warning_once(
+"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
+"negative scales. Accuracy will likely be degraded. This is "
+"because it changes the scales from FP8-S1E4M3 to a special "
+"FP8-S0E5M3 format to speedup the dequantization.")
# convert to half first, we would convert to fp8 later
marlin_scales = marlin_scales.to(torch.half)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment