test_flex_attention.py 9.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
"""Integration tests for FlexAttention backend vs default backend"""

import pytest
import torch
from packaging import version

9
from tests.utils import set_random_seed
10
11
12
13
14
15
from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
)
16
from vllm.v1.attention.backends.flex_attention import (
17
    BlockSparsityHint,
18
19
20
    FlexAttentionMetadataBuilder,
    physical_to_logical_mapping,
)
21

22
from ..models.utils import check_embeddings_close, check_logprobs_close
23
24
25

# Installed PyTorch version, parsed once so the skipif guards in this module
# can compare it against the thresholds below.
TORCH_VERSION = version.parse(torch.__version__)
# Lower bound checked by the skipif guards of the backend-comparison tests.
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
26
# Lower bound checked by the skipif guards of the tests that exercise the
# direct block-mask build path.
# NOTE(review): the ".dev0" suffix presumably lets 2.9 dev/nightly builds pass
# the comparison -- confirm against packaging's version ordering.
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
27
28
29
30
31
32


@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_vs_default_backend(vllm_runner):
    """Check that FlexAttention and the default backend agree on generation.

    The same prompts are greedily decoded twice under the same random seed:
    once with the FLEX_ATTENTION backend selected and once with whatever
    backend the runner picks by default. The two sets of outputs are then
    compared with ``check_logprobs_close``.

    Args:
        vllm_runner: Fixture used as a context manager to build a model runner.
    """
    seed = 42
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    max_tokens, num_logprobs = 24, 5
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Engine options shared by both runs; only the backend-specific options
    # differ between the two invocations below.
    shared_kwargs = dict(
        runner="generate",
        tensor_parallel_size=1,
        num_gpu_blocks_override=128,
        enforce_eager=True,
    )

    # Pass 1: FlexAttention backend.
    set_random_seed(seed)
    with vllm_runner(
        model_name,
        attention_config={"backend": "FLEX_ATTENTION"},
        **shared_kwargs,
    ) as flex_llm:
        flex_result = flex_llm.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    # Pass 2: default backend, re-seeded so both passes start identically.
    set_random_seed(seed)
    with vllm_runner(
        model_name,
        gpu_memory_utilization=0.85,
        **shared_kwargs,
    ) as default_llm:
        default_result = default_llm.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=flex_result,
        outputs_1_lst=default_result,
        name_0="flex",
        name_1="default",
    )
83
84


85
86
87
88
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_encoder_flex_attention_vs_default_backend(vllm_runner):
    """Check that FlexAttention and the default backend agree on embeddings.

    An encoder (pooling) model embeds the same prompts twice: once with the
    FLEX_ATTENTION backend selected and once with the runner's default
    backend. The embeddings are compared with ``check_embeddings_close``
    using a tolerance of 1e-2.

    Args:
        vllm_runner: Fixture used as a context manager to build a model runner.
    """
    model_name = "BAAI/bge-base-en-v1.5"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Options common to both runs; the flex run only adds attention_config.
    shared_kwargs = dict(
        runner="pooling",
        dtype=torch.bfloat16,
        tensor_parallel_size=1,
        max_model_len=100,
        enforce_eager=True,
    )

    # Pass 1: FlexAttention backend.
    with vllm_runner(
        model_name,
        attention_config={"backend": "FLEX_ATTENTION"},
        **shared_kwargs,
    ) as flex_llm:
        flex_embeddings = flex_llm.embed(prompts)

    # Pass 2: default backend.
    with vllm_runner(model_name, **shared_kwargs) as default_llm:
        default_embeddings = default_llm.embed(prompts)

    check_embeddings_close(
        embeddings_0_lst=flex_embeddings,
        embeddings_1_lst=default_embeddings,
        name_0="flex",
        name_1="default",
        tol=1e-2,
    )


134
135
136
137
138
139
140
141
142
143
144
145
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
    # The guard compares against DIRECT_BUILD_VERSION (2.9.dev0), so the
    # reported reason must name 2.9 rather than 2.7.
    reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_mask_direct_vs_slow_path():
    """Test that direct path block mask is a superset of slow path.

    The direct path may include extra blocks for performance (over-estimation),
    but must include all blocks that the slow path determines are necessary.

    The same builder produces both masks: it is first used as constructed,
    then again after ``direct_build`` is set to ``False``. For every group
    (row of ``kv_num_blocks[0, 0]``) the set of KV block indices from the
    slow path must be contained in the set from the direct path.
    """
    device = torch.device("cuda")

    vllm_config = create_vllm_config(
        model_name="meta-llama/Meta-Llama-3-8B", block_size=16, max_model_len=1024
    )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    # Use a mixed batch that will create groups spanning multiple sequences
    batch_spec = BatchSpec(
        seq_lens=[35, 64, 128, 256], query_lens=[33, 5, 32, 64], name="test_mixed_batch"
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device
    )

    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)

    metadata_direct = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )
    # Flip the builder to the slow path for the second build.
    builder.direct_build = False
    metadata_slow = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    assert metadata_direct.block_mask is not None
    assert metadata_slow.block_mask is not None

    # Extract block indices for comparison, B, H are the same
    direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
    slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
    direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
    slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]

    # main test: every block needed by slow path must be in direct path
    num_groups = direct_num.shape[0]
    missing_details = []

    for group_idx in range(num_groups):
        # Only the first kv_num_blocks entries of each row are populated.
        direct_blocks = set(direct_indices[group_idx, : direct_num[group_idx]].tolist())
        slow_blocks = set(slow_indices[group_idx, : slow_num[group_idx]].tolist())

        missing_blocks = slow_blocks - direct_blocks
        if missing_blocks:
            missing_details.append(
                f"Group {group_idx}: missing {sorted(missing_blocks)}"
            )

    # An empty list means every group's slow-path blocks were covered.
    assert not missing_details, (
        "Direct path is missing blocks required by slow path:\n"
        + "\n".join(missing_details)
    )
199
200


201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def test_physical_to_logical_mapping_handles_reused_blocks():
    """Regression test for the physical -> logical block inverse mapping.

    Sliding-window / hybrid attention layers can reuse physical KV-cache
    blocks over time, so a physical block id may show up more than once in a
    block table. The inverse mapping is expected to report the latest logical
    index for such a block, and to ignore block-table padding beyond the
    sequence length.
    """
    block_size = 16

    # Case 1: one valid block followed by zero padding. The padding entries
    # must not cause physical block 0 to be reported as in use.
    padded_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
    one_block_len = torch.tensor([1 * block_size], dtype=torch.int32)
    padded_mapping = physical_to_logical_mapping(
        block_table=padded_table,
        seq_lens=one_block_len,
        block_size=block_size,
        total_blocks=10,
    )
    assert padded_mapping[0, 0].item() == -1
    assert padded_mapping[0, 6].item() == 0

    # Case 2: physical block 2 occupies logical slots 0 and 1; the mapping
    # must resolve to the later slot (1).
    reused_table = torch.tensor([[2, 2, 5]], dtype=torch.int32)
    three_block_len = torch.tensor([3 * block_size], dtype=torch.int32)
    reused_mapping = physical_to_logical_mapping(
        block_table=reused_table,
        seq_lens=three_block_len,
        block_size=block_size,
        total_blocks=8,
    )
    assert reused_mapping[0, 2].item() == 1


227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
    reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_sparsity_hint_prunes_blocks():
    """Verify that a BlockSparsityHint removes KV blocks in the direct build.

    Two metadata objects are built from the same single-sequence batch. The
    first has its block mask rebuilt through the direct path without a hint
    and must contain a row with more than one KV block. The second gets a
    hint that keeps only ``q_block == kv_block`` before the rebuild, after
    which no row may hold more than one KV block.
    """
    device = torch.device("cuda")

    vllm_config = create_vllm_config(
        model_name="facebook/opt-125m", block_size=16, max_model_len=1024
    )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    # A single full-prefill sequence (query length == sequence length).
    batch_spec = BatchSpec(seq_lens=[256], query_lens=[256], name="test_sparsity_hint")
    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device
    )

    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)

    # Baseline: no hint, mask rebuilt explicitly via the direct path.
    baseline = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )
    baseline.block_mask = baseline._build_block_mask_direct()
    assert baseline.block_mask.kv_num_blocks.max().item() > 1

    def keep_diagonal_only(q_block_idx, kv_block_idx, block_size):
        # Retain a KV block only when it lines up with the query block.
        return q_block_idx == kv_block_idx

    # Hinted: attach the diagonal-only hint before rebuilding the mask.
    hinted = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )
    hinted.block_sparsity_hint = BlockSparsityHint(hint_fn=keep_diagonal_only)
    hinted.block_mask = hinted._build_block_mask_direct()
    assert hinted.block_mask.kv_num_blocks.max().item() <= 1


277
278
# Allow running this test module directly as a script; pytest is invoked on
# this file only.
if __name__ == "__main__":
    pytest.main([__file__])