Unverified Commit 9a1d20a8 authored by Angela Yi's avatar Angela Yi Committed by GitHub
Browse files

[CI] Add warmup run in test_fusion_attn (#31183)


Signed-off-by: default avatarangelayi <yiangela7@gmail.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 309a8f66
......@@ -305,8 +305,12 @@ def test_attention_quant_pattern(
model_class: type[AttentionQuantPatternModel],
backend: AttentionBackendEnum,
dist_init,
monkeypatch,
use_fresh_inductor_cache,
):
"""Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
......@@ -363,13 +367,15 @@ def test_attention_quant_pattern(
vllm_config=vllm_config_unfused,
)
model_unfused = model_unfused.to(device)
result_unfused_0 = model_unfused(q, k, v) # noqa: F841 HACK: See #131044
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without fusion
# Still compile so query QuantFP8 has closer numerics
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v)
compiled_unfused = torch.compile(model_unfused, fullgraph=True)
result_unfused = compiled_unfused(q, k, v)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
......@@ -399,24 +405,26 @@ def test_attention_quant_pattern(
cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# HACK: See https://github.com/vllm-project/vllm/issues/31044
result_fused_0 = model_fused(q, k, v) # noqa: F841
# Compile model with fusion enabled
model_compiled = torch.compile(
compiled_fused = torch.compile(
model_fused, backend=test_backend, fullgraph=True
)
assert model_compiled.attn._o_scale_float is None
assert compiled_fused.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v)
result_fused = compiled_fused(q, k, v)
if backend == AttentionBackendEnum.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded
# _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert compiled_fused.attn._o_scale_float is not None
result_fused_2 = compiled_fused(q, k, v)
assert model_compiled.attn._o_scale_float is not None
assert compiled_fused.attn._o_scale_float is not None
torch.testing.assert_close(
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
......@@ -474,4 +482,4 @@ def test_attention_quant_pattern(
)
# Check that results are close
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment