Commit 0727cc9e authored by Vadim Gimpelson, committed by GitHub

[BUGFIX] Fix `test_mla_backends.py`. Scale MLA projection weights to prevent numerical instability (#32529)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
parent a0490be8
@@ -504,6 +504,14 @@ def test_backend_correctness(
W_UV = torch.randn(
kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
)
# Scale weights to produce realistic magnitude outputs.
# Without scaling, projection output has std ~sqrt(kv_lora_rank) ≈ 22.6,
# causing extreme attention scores and numerical instability in LSE merging.
weight_scale = 1.0 / (kv_lora_rank**0.5)
W_UK = W_UK * weight_scale
W_UV = W_UV * weight_scale
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST):
......
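
The magnitude claim in the added comment can be checked numerically. The sketch below is not part of the commit. It uses only the Python standard library, assumes `kv_lora_rank = 512` (the value for which `sqrt(kv_lora_rank)` is about 22.6), and assumes both the projection input and the weights are drawn from a standard normal distribution, as `torch.randn` does for `W_UV`. The helper `projection_output_std` is a hypothetical name introduced here for illustration.

```python
import math
import random
import statistics


def projection_output_std(kv_lora_rank, weight_scale, num_samples=1000, seed=0):
    """Estimate the std of one output element of x @ W.

    Assumes x and W entries are i.i.d. N(0, 1) and W is multiplied by weight_scale.
    """
    rng = random.Random(seed)
    outputs = []
    for _ in range(num_samples):
        # One output element is a dot product over the kv_lora_rank dimension.
        acc = sum(
            rng.gauss(0.0, 1.0) * (rng.gauss(0.0, 1.0) * weight_scale)
            for _ in range(kv_lora_rank)
        )
        outputs.append(acc)
    return statistics.pstdev(outputs)


kv_lora_rank = 512  # assumed value; sqrt(512) is about 22.6
weight_scale = 1.0 / math.sqrt(kv_lora_rank)  # same formula as in the diff

unscaled_std = projection_output_std(kv_lora_rank, weight_scale=1.0)
scaled_std = projection_output_std(kv_lora_rank, weight_scale=weight_scale)
print(f"unscaled std ~ {unscaled_std:.2f}, scaled std ~ {scaled_std:.2f}")
```

Under these assumptions the unscaled estimate should land near `sqrt(kv_lora_rank)` and the scaled estimate near 1, which is the "realistic magnitude" the comment refers to.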