Unverified commit 4254aeb5, authored by Carl Y and committed by GitHub
Browse files

[fix] flaky test_mla_attn_quant_fusion.py (#40530)


Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
parent aad88f84
@@ -83,10 +83,6 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
         self.vllm_config = vllm_config
         self.dtype = vllm_config.model_config.dtype
-        # Create kv_b_proj (ColumnParallelLinear) on device.
-        # Reuse weights from prior model instance when available, because
-        # ColumnParallelLinear may get NaN from recycled CUDA memory after
-        # torch.compile runs in the same process.
         kv_b_proj = ColumnParallelLinear(
             input_size=kv_lora_rank,
             output_size=num_heads * (qk_nope_head_dim + v_head_dim),
@@ -96,8 +92,7 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
         kv_b_proj_weight = kwargs.get("kv_b_proj_weight")
         if kv_b_proj_weight is not None:
             kv_b_proj.weight.data.copy_(kv_b_proj_weight)
-        elif kv_b_proj.weight.data.isnan().any():
-            # Sanitize NaN from recycled CUDA memory
+        else:
             kv_b_proj.weight.data.normal_()
         # Create MLAAttention
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment