# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""

import pytest
import torch
from packaging import version

from tests.utils import set_random_seed
from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
)
from vllm.v1.attention.backends.flex_attention import (
    BlockSparsityHint,
    FlexAttentionMetadataBuilder,
    physical_to_logical_mapping,
)

from ..models.utils import check_embeddings_close, check_logprobs_close

# Installed PyTorch version; compared against the thresholds below in the
# `pytest.mark.skipif` gates of the tests in this module.
TORCH_VERSION = version.parse(torch.__version__)
# Minimum PyTorch version for the end-to-end FlexAttention backend tests.
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
# Minimum PyTorch version for the tests that exercise the direct block-mask
# build path (`test_block_mask_direct_vs_slow_path`, sparsity-hint test).
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_full_cudagraphs(vllm_runner):
    """Check the numerics of FlexAttention full-cudagraph support.

    Greedy generation is run twice with the FLEX_ATTENTION backend on the
    same prompts and seed: first with ``enforce_eager=True`` and then with
    ``enforce_eager=False`` (compiled). The top-``num_logprobs`` logprobs
    of both runs must be close.
    """
    model = "Qwen/Qwen2.5-1.5B-Instruct"
    seed, max_tokens, num_logprobs = 42, 24, 5
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Reference: FlexAttention in eager mode.
    set_random_seed(seed)
    with vllm_runner(
        model,
        runner="generate",
        tensor_parallel_size=1,
        num_gpu_blocks_override=128,
        enforce_eager=True,
        attention_config={"backend": "FLEX_ATTENTION"},
    ) as eager_llm:
        eager_result = eager_llm.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    # Candidate: FlexAttention with compilation enabled.
    set_random_seed(seed)
    with vllm_runner(
        model,
        runner="generate",
        tensor_parallel_size=1,
        num_gpu_blocks_override=128,
        enforce_eager=False,
        gpu_memory_utilization=0.85,
        attention_config={"backend": "FLEX_ATTENTION"},
    ) as compiled_llm:
        compiled_result = compiled_llm.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=eager_result,
        outputs_1_lst=compiled_result,
        name_0="eager",
        name_1="compile",
    )


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_vs_default_backend(vllm_runner):
    """Test that FlexAttention produces the same outputs as the default backend.

    This test compares the greedy logprobs from the FlexAttention backend
    with those of the default backend (no ``attention_config`` override),
    ensuring they are close when using the same seed. Both runs are eager.
    """
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    seed = 42
    max_tokens = 24
    num_logprobs = 5
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Run with flex attention
    set_random_seed(seed)
    with vllm_runner(
        model_name,
        runner="generate",
        tensor_parallel_size=1,
        num_gpu_blocks_override=128,
        enforce_eager=True,
        attention_config={"backend": "FLEX_ATTENTION"},
    ) as llm_flex:
        output_flex = llm_flex.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    # Run with default backend
    set_random_seed(seed)
    with vllm_runner(
        model_name,
        runner="generate",
        tensor_parallel_size=1,
        num_gpu_blocks_override=128,
        enforce_eager=True,
        gpu_memory_utilization=0.85,
    ) as llm_default:
        output_default = llm_default.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=output_flex,
        outputs_1_lst=output_default,
        name_0="flex",
        name_1="default",
    )


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_encoder_flex_attention_vs_default_backend(vllm_runner):
    """Test that FlexAttention produces the same outputs as the default backend.

    This test compares the embeddings from the FlexAttention backend with
    those of the default backend for an encoder (pooling) model run in
    bfloat16, and requires them to agree within ``tol=1e-2``.
    """
    model_name = "BAAI/bge-base-en-v1.5"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Run with flex attention
    with vllm_runner(
        model_name,
        runner="pooling",
        dtype=torch.bfloat16,
        tensor_parallel_size=1,
        max_model_len=100,
        enforce_eager=True,
        attention_config={"backend": "FLEX_ATTENTION"},
    ) as llm_flex:
        flex_outputs = llm_flex.embed(prompts)

    # Run with default backend (no attention_config override)
    with vllm_runner(
        model_name,
        runner="pooling",
        dtype=torch.bfloat16,
        tensor_parallel_size=1,
        max_model_len=100,
        enforce_eager=True,
    ) as llm_default:
        default_outputs = llm_default.embed(prompts)

    check_embeddings_close(
        embeddings_0_lst=flex_outputs,
        embeddings_1_lst=default_outputs,
        name_0="flex",
        name_1="default",
        tol=1e-2,
    )


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
    reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_mask_direct_vs_slow_path():
    """Test that direct path block mask is a superset of slow path.

    The direct path may include extra blocks for performance (over-estimation),
    but must include all blocks that the slow path determines are necessary.
    """
    device = torch.device("cuda")

    vllm_config = create_vllm_config(
        model_name="meta-llama/Meta-Llama-3-8B", block_size=16, max_model_len=1024
    )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    # Use a mixed batch that will create groups spanning multiple sequences
    batch_spec = BatchSpec(
        seq_lens=[35, 64, 128, 256], query_lens=[33, 5, 32, 64], name="test_mixed_batch"
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device
    )

    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)

    # First build with the builder's default configuration (treated here as
    # the direct path), then disable `direct_build` to force the slow path.
    metadata_direct = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )
    builder.direct_build = False
    metadata_slow = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    assert metadata_direct.block_mask is not None
    assert metadata_slow.block_mask is not None

    # Extract block indices for comparison, B, H are the same
    direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
    slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
    direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
    slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]

    # Main check: every block needed by the slow path must be in the direct
    # path. Only the first `kv_num_blocks[g]` entries of each row are valid.
    missing_details = []
    for group_idx in range(direct_num.shape[0]):
        direct_blocks = set(direct_indices[group_idx, : direct_num[group_idx]].tolist())
        slow_blocks = set(slow_indices[group_idx, : slow_num[group_idx]].tolist())

        missing_blocks = slow_blocks - direct_blocks
        if missing_blocks:
            missing_details.append(
                f"Group {group_idx}: missing {sorted(missing_blocks)}"
            )

    assert not missing_details, (
        "Direct path is missing blocks required by slow path:\n"
        + "\n".join(missing_details)
    )


def test_physical_to_logical_mapping_handles_reused_blocks():
    """Regression test: reused physical blocks map to the latest logical block.

    With sliding-window / hybrid attention layers, a physical KV-cache block
    id can be reused over time, so the inverse (physical -> logical) mapping
    must pick the most recent logical index for that id. Entries with no
    live logical block are expected to read -1.
    """
    block_size = 16

    # Case 1: only one block is valid (seq_len == one block); the zero
    # padding that follows in the block table must not mark block 0 as live.
    padded_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
    one_block_len = torch.tensor([block_size], dtype=torch.int32)
    mapping = physical_to_logical_mapping(
        block_table=padded_table,
        seq_lens=one_block_len,
        block_size=block_size,
        total_blocks=10,
    )
    assert mapping[0, 0].item() == -1
    assert mapping[0, 6].item() == 0

    # Case 2: physical block 2 appears at logical indices 0 and 1; the
    # mapping must point at the later one.
    reused_table = torch.tensor([[2, 2, 5]], dtype=torch.int32)
    three_block_len = torch.tensor([3 * block_size], dtype=torch.int32)
    mapping = physical_to_logical_mapping(
        block_table=reused_table,
        seq_lens=three_block_len,
        block_size=block_size,
        total_blocks=8,
    )
    assert mapping[0, 2].item() == 1


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
    reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_sparsity_hint_prunes_blocks():
    """Verify that a BlockSparsityHint prunes KV blocks in the direct build.

    A baseline mask (no hint) built through `_build_block_mask_direct` must
    have more than one KV block for some query block. With a hint that keeps
    only diagonal pairs (q_block == kv_block), every query block must keep
    at most one KV block.
    """
    device = torch.device("cuda")

    vllm_config = create_vllm_config(
        model_name="facebook/opt-125m",
        block_size=16,
        max_model_len=1024,
    )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    common_attn_metadata = create_common_attn_metadata(
        BatchSpec(seq_lens=[256], query_lens=[256], name="test_sparsity_hint"),
        vllm_config.cache_config.block_size,
        device,
    )
    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)

    def build_metadata():
        # Fresh metadata object per call so the two scenarios do not share state.
        return builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    # Baseline: no hint, so the mask should contain off-diagonal blocks.
    baseline = build_metadata()
    baseline.block_mask = baseline._build_block_mask_direct()
    assert baseline.block_mask.kv_num_blocks.max().item() > 1

    def diagonal_hint(q_block_idx, kv_block_idx, block_size):
        return q_block_idx == kv_block_idx

    # Hinted: only the diagonal may survive.
    hinted = build_metadata()
    hinted.block_sparsity_hint = BlockSparsityHint(hint_fn=diagonal_hint)
    hinted.block_mask = hinted._build_block_mask_direct()
    assert hinted.block_mask.kv_num_blocks.max().item() <= 1


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    pytest.main(args=[__file__])