"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "202351d5bf7a625852316516c0e1d00745bf88ab"
Unverified Commit 73105554 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Bugfix] Fix marlin nvfp4 rescaling (#37502)


Signed-off-by: default avatarJinzhen Lin <jinzhen.ljz@antgroup.com>
parent 96b5004b
...@@ -43,9 +43,9 @@ def _nvfp4_compute_scale_factor( ...@@ -43,9 +43,9 @@ def _nvfp4_compute_scale_factor(
ws_float = marlin_scales.float() * (2**7) ws_float = marlin_scales.float() * (2**7)
nonzero_mask = ws_float > 0 nonzero_mask = ws_float > 0
if nonzero_mask.any(): if nonzero_mask.any():
min_val = ws_float[nonzero_mask].min() max_val = ws_float[nonzero_mask].max()
if min_val < 2: if max_val < 448 * (2**7):
sf = (2 / min_val).log2().ceil().exp2() sf = (448 * (2**7) / max_val).log2().floor().exp2()
return sf.item() return sf.item()
return 1.0 return 1.0
...@@ -105,7 +105,9 @@ def nvfp4_marlin_process_scales( ...@@ -105,7 +105,9 @@ def nvfp4_marlin_process_scales(
if scale_factor > 1.0: if scale_factor > 1.0:
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half) marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 marlin_scales = marlin_scales * (2**7)
marlin_scales[marlin_scales < 2] = 0
marlin_scales = marlin_scales.view(torch.int16) << 1
marlin_scales = marlin_scales.view(torch.float8_e4m3fn) marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
marlin_scales = marlin_scales[:, 1::2].contiguous() marlin_scales = marlin_scales[:, 1::2].contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment