# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""

import random

import numpy as np
import pytest
import torch
from packaging import version

from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
)
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder

from ..models.utils import check_embeddings_close, check_logprobs_close

# Installed torch version; compared against the gates below in skipif markers.
TORCH_VERSION = version.parse(torch.__version__)
# Oldest torch on which the end-to-end FlexAttention comparison tests run.
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
# Oldest torch on which the direct-vs-slow block-mask test runs.
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")


def set_seed(seed):
    """Seed Python, NumPy and torch RNGs (plus all CUDA devices when present).

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
    """Test that FlexAttention produces the same outputs as the default backend.

    This test compares the outputs from the FlexAttention backend with
    the default backend, ensuring they are similar when using the same seed.
    Greedy top-``num_logprobs`` logprobs from both runs are compared with
    ``check_logprobs_close``.
    """
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    seed = 42
    max_tokens = 24
    num_logprobs = 5
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Run with flex attention
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")

        set_seed(seed)
        with vllm_runner(
            model_name,
            runner="generate",
            tensor_parallel_size=1,
            num_gpu_blocks_override=128,
            enforce_eager=True,
        ) as llm_flex:
            output_flex = llm_flex.generate_greedy_logprobs(
                prompts, max_tokens, num_logprobs
            )

    # Run with default backend
    with monkeypatch.context() as m:
        # Drop any backend override inherited from the outer environment
        # (e.g. a CI job exporting VLLM_ATTENTION_BACKEND=FLEX_ATTENTION).
        # Otherwise this "default" run could also use FlexAttention and the
        # comparison below would be vacuous.
        m.delenv("VLLM_ATTENTION_BACKEND", raising=False)

        set_seed(seed)
        # NOTE(review): only this run lowers gpu_memory_utilization; confirm
        # the asymmetry with the flex run is intentional.
        with vllm_runner(
            model_name,
            runner="generate",
            tensor_parallel_size=1,
            num_gpu_blocks_override=128,
            enforce_eager=True,
            gpu_memory_utilization=0.85,
        ) as llm_default:
            output_default = llm_default.generate_greedy_logprobs(
                prompts, max_tokens, num_logprobs
            )

    check_logprobs_close(
        outputs_0_lst=output_flex,
        outputs_1_lst=output_default,
        name_0="flex",
        name_1="default",
    )


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
    """Test that FlexAttention produces the same outputs as the default backend.

    This test compares the outputs from the FlexAttention backend with
    the default backend for encoder models. Embeddings from both runs are
    compared with ``check_embeddings_close`` at ``tol=1e-2``.
    """
    model_name = "BAAI/bge-base-en-v1.5"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Run with flex attention
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        with vllm_runner(
            model_name,
            runner="pooling",
            dtype=torch.bfloat16,
            tensor_parallel_size=1,
            max_model_len=100,
            enforce_eager=True,
        ) as llm_flex:
            flex_outputs = llm_flex.embed(prompts)

    # Run with default backend
    with monkeypatch.context() as m:
        # Drop any backend override inherited from the outer environment, so
        # this run really exercises the default backend. This must happen
        # before the runner is constructed, hence the nested `with`.
        m.delenv("VLLM_ATTENTION_BACKEND", raising=False)
        with vllm_runner(
            model_name,
            runner="pooling",
            dtype=torch.bfloat16,
            tensor_parallel_size=1,
            max_model_len=100,
            enforce_eager=True,
        ) as llm_default:
            default_outputs = llm_default.embed(prompts)

    check_embeddings_close(
        embeddings_0_lst=flex_outputs,
        embeddings_1_lst=default_outputs,
        name_0="flex",
        name_1="default",
        tol=1e-2,
    )


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
    reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_mask_direct_vs_slow_path():
    """Test that direct path block mask is a superset of slow path.

    The direct path may include extra blocks for performance (over-estimation),
    but must include all blocks that the slow path determines are necessary.
    """
    device = torch.device("cuda")

    vllm_config = create_vllm_config(
        model_name="meta-llama/Meta-Llama-3-8B", block_size=16, max_model_len=1024
    )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    # Use a mixed batch that will create groups spanning multiple sequences
    batch_spec = BatchSpec(
        seq_lens=[35, 64, 128, 256], query_lens=[33, 5, 32, 64], name="test_mixed_batch"
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device
    )

    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)

    # First build uses the builder's default path. This assumes direct_build
    # is enabled by default on torch >= DIRECT_BUILD_VERSION (TODO confirm).
    metadata_direct = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )
    # Force the slow path and rebuild from the identical metadata.
    builder.direct_build = False
    metadata_slow = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    assert metadata_direct.block_mask is not None
    assert metadata_slow.block_mask is not None

    # Extract block indices for comparison; B and H are the same, so only
    # entry [0, 0] is inspected.
    direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
    slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
    direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
    slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]

    # main test: every block needed by slow path must be in direct path
    num_groups = direct_num.shape[0]
    missing_details = []
    for group_idx in range(num_groups):
        # Only the first kv_num_blocks entries of each kv_indices row are valid.
        direct_blocks = set(direct_indices[group_idx, : direct_num[group_idx]].tolist())
        slow_blocks = set(slow_indices[group_idx, : slow_num[group_idx]].tolist())

        missing_blocks = slow_blocks - direct_blocks
        if missing_blocks:
            missing_details.append(
                f"Group {group_idx}: missing {sorted(missing_blocks)}"
            )

    assert not missing_details, (
        "Direct path is missing blocks required by slow path:\n"
        + "\n".join(missing_details)
    )


if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports failures to the shell/CI. Discarding the result always exits 0.
    raise SystemExit(pytest.main([__file__]))