Unverified commit 9a1d20a8, authored by Angela Yi, committed by GitHub
Browse files

[CI] Add warmup run in test_fusion_attn (#31183)


Signed-off-by: angelayi <yiangela7@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 309a8f66
...@@ -305,8 +305,12 @@ def test_attention_quant_pattern( ...@@ -305,8 +305,12 @@ def test_attention_quant_pattern(
model_class: type[AttentionQuantPatternModel], model_class: type[AttentionQuantPatternModel],
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
dist_init, dist_init,
monkeypatch,
use_fresh_inductor_cache,
): ):
"""Test AttentionStaticQuantPattern fusion pass""" """Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if backend == AttentionBackendEnum.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
): ):
...@@ -363,13 +367,15 @@ def test_attention_quant_pattern( ...@@ -363,13 +367,15 @@ def test_attention_quant_pattern(
vllm_config=vllm_config_unfused, vllm_config=vllm_config_unfused,
) )
model_unfused = model_unfused.to(device) model_unfused = model_unfused.to(device)
result_unfused_0 = model_unfused(q, k, v) # noqa: F841 HACK: See #131044
forward_ctx = get_forward_context() forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without fusion # Run model directly without fusion
# Still compile so query QuantFP8 has closer numerics # Still compile so query QuantFP8 has closer numerics
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v) compiled_unfused = torch.compile(model_unfused, fullgraph=True)
result_unfused = compiled_unfused(q, k, v)
# Run model with attn fusion enabled # Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
...@@ -399,24 +405,26 @@ def test_attention_quant_pattern( ...@@ -399,24 +405,26 @@ def test_attention_quant_pattern(
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# HACK: See https://github.com/vllm-project/vllm/issues/31044
result_fused_0 = model_fused(q, k, v) # noqa: F841
# Compile model with fusion enabled # Compile model with fusion enabled
model_compiled = torch.compile( compiled_fused = torch.compile(
model_fused, backend=test_backend, fullgraph=True model_fused, backend=test_backend, fullgraph=True
) )
assert model_compiled.attn._o_scale_float is None assert compiled_fused.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v) result_fused = compiled_fused(q, k, v)
if backend == AttentionBackendEnum.FLASHINFER: if backend == AttentionBackendEnum.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward # With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's # pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded # _o_scale_float, the 2nd round should reuse the loaded
# _o_scale_float # _o_scale_float
assert model_compiled.attn._o_scale_float is not None assert compiled_fused.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v) result_fused_2 = compiled_fused(q, k, v)
assert model_compiled.attn._o_scale_float is not None assert compiled_fused.attn._o_scale_float is not None
torch.testing.assert_close( torch.testing.assert_close(
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2 result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
...@@ -474,4 +482,4 @@ def test_attention_quant_pattern( ...@@ -474,4 +482,4 @@ def test_attention_quant_pattern(
) )
# Check that results are close # Check that results are close
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2) torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment