test_tree_attention.py 19.1 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

6
import pytest
7
8
import torch

9
10
11
from tests.v1.attention.utils import (
    create_standard_kv_cache_spec,
    create_vllm_config,
12
    try_backend_includes_kv_cache_update,
13
    try_get_attention_backend,
14
)
15
from vllm.config import ParallelConfig, SpeculativeConfig
16
from vllm.platforms import current_platform
17
from vllm.utils.torch_utils import set_random_seed
18
from vllm.v1.attention.backend import CommonAttentionMetadata
19
20
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum
21

22
23
# Device string for tensors allocated at import time (e.g. the
# MockAttentionLayer scale tensors) and throughout the test.
DEVICE_TYPE = current_platform.device_type

# Nothing in this module can run without the varlen flash-attention kernel,
# so skip collection of the whole module when it is missing.
if not is_flash_attn_varlen_func_available():
    pytest.skip(
        "This test requires flash_attn_varlen_func, but it's not available.",
        allow_module_level=True,
    )

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# --------------------------------------------------------------------------- #
#  KV cache layout adaptation
# --------------------------------------------------------------------------- #
# Two KV cache layouts exist across backends:
#
#   Flash layout: (2, num_blocks, block_size, num_kv_heads, head_size)
#     - dim 0 separates key (index 0) and value (index 1)
#     - Used by: FLASH_ATTN, TREE_ATTN, ROCM_AITER_FA, ROCM_ATTN
#
#   Block layout: (num_blocks, 2, block_size, num_kv_heads, head_size)
#     - dim 1 separates key (index 0) and value (index 1)
#     - Used by: TRITON_ATTN
#
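# The test creates KV caches in flash layout (the canonical format used by
# tree attention). When a reference backend needs block layout,
# _adapt_kv_cache_for_backend transposes dims 0 and 1 and makes the result
# contiguous, which in general materializes a separate copy of the cache.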
#
# Note: ROCM_ATTN uses flash layout for storage but its forward path calls
# PagedAttention.split_kv_cache which reinterprets the raw memory as paged
# layout (num_blocks, num_kv_heads, head_size//x, block_size, x). This is
# a view-level incompatibility, not a transpose - see the TODO in
# _get_available_reference_backends for details.
#
# TODO: Replace this mapping with a `KV_CACHE_LAYOUT` class attribute on each
# AttentionImpl so the layout is self-documented by the backend itself, e.g.:
#     class TritonAttentionImpl(AttentionImpl):
#         KV_CACHE_LAYOUT = "block"
# --------------------------------------------------------------------------- #

# Backends whose KV cache is block layout (num_blocks, 2, ...). Any backend
# not listed here is treated as flash layout (2, num_blocks, ...).
_BLOCK_KV_LAYOUT_BACKENDS = frozenset({AttentionBackendEnum.TRITON_ATTN})

# Backends whose do_kv_cache_update needs engine-level state (e.g. a
# ForwardContext) that this harness does not provide. Their cache is flash
# layout, so forward_attention() skips do_kv_cache_update for them and
# writes the new K/V with reshape_and_cache_flash instead.
_NEEDS_DIRECT_CACHE_UPDATE = frozenset({AttentionBackendEnum.ROCM_AITER_FA})

# Backends that cannot currently serve as a reference in this harness; the
# reasons are recorded as TODOs in _get_available_reference_backends.
_INCOMPATIBLE_REFERENCE_BACKENDS = frozenset(
    {AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_ATTN}
)


def _adapt_kv_cache_for_backend(
    kv_cache: torch.Tensor,
    backend: AttentionBackendEnum,
) -> torch.Tensor:
    """Return ``kv_cache`` in the layout that ``backend`` expects.

    The input is flash layout ``(2, num_blocks, ...)``. Backends listed in
    ``_BLOCK_KV_LAYOUT_BACKENDS`` receive block layout ``(num_blocks, 2, ...)``.
    Every other backend receives the very same tensor object back.
    """
    if backend not in _BLOCK_KV_LAYOUT_BACKENDS:
        # Already in the right layout: hand back the caller's tensor as-is.
        return kv_cache
    # Swap the K/V axis with the block axis. ``contiguous()`` materializes the
    # transposed view, so in general the result does not share storage with
    # ``kv_cache`` and later writes to it are not visible to the caller.
    block_layout_view = kv_cache.transpose(0, 1)
    return block_layout_view.contiguous()


def _get_platform_default_backend() -> AttentionBackendEnum:
    """Return the enum member for the backend the platform auto-selects.

    The platform is queried with a fixed selector config (block size 32,
    head size 128, bfloat16, no MLA, no sparse). The returned class path is
    matched against ``AttentionBackendEnum.get_path()``. Members whose
    ``get_path()`` raises ``ValueError`` are ignored.

    Raises:
        RuntimeError: if no enum member maps to the returned path.
    """
    from vllm.v1.attention.selector import AttentionSelectorConfig

    selector_config = AttentionSelectorConfig(
        block_size=32,
        kv_cache_dtype="auto",
        use_mla=False,
        use_sparse=False,
        head_size=128,
        dtype=torch.bfloat16,
    )
    selected_path = current_platform.get_attn_backend_cls(
        selected_backend=None,
        attn_selector_config=selector_config,
    )

    def _points_at_selected(candidate: AttentionBackendEnum) -> bool:
        # Some members have no resolvable path; treat them as non-matches.
        try:
            return candidate.get_path() == selected_path
        except ValueError:
            return False

    matched = next(
        (member for member in AttentionBackendEnum if _points_at_selected(member)),
        None,
    )
    if matched is None:
        raise RuntimeError(
            f"Platform returned backend path '{selected_path}' "
            f"that doesn't match any AttentionBackendEnum member."
        )
    return matched


def _get_available_reference_backends() -> list[AttentionBackendEnum]:
    """List the backends that tree attention is validated against.

    Any non-ROCm platform (the CUDA case) gets FLASH_ATTN only. On ROCm the
    list is built as follows:

    * the platform's auto-selected backend, unless it appears in
      ``_INCOMPATIBLE_REFERENCE_BACKENDS``;
    * then TRITON_ATTN, unless it is already in the list.
    """
    if not current_platform.is_rocm():
        # CUDA: flash attention.
        return [AttentionBackendEnum.FLASH_ATTN]

    chosen: list[AttentionBackendEnum] = []

    # 1. Whatever the platform would auto-select at runtime.
    platform_pick = _get_platform_default_backend()
    if platform_pick not in _INCOMPATIBLE_REFERENCE_BACKENDS:
        chosen.append(platform_pick)

    # 2. TRITON_ATTN - always available on ROCm (skip if already chosen).
    triton = AttentionBackendEnum.TRITON_ATTN
    if triton not in chosen:
        chosen.append(triton)

    # TODO: Enable ROCM_ATTN. Its forward path uses
    # PagedAttention.split_kv_cache which reinterprets the raw
    # cache memory as paged layout:
    #   key:   (num_blocks, num_kv_heads, head_size//x, block_size, x)
    #   value: (num_blocks, num_kv_heads, head_size, block_size)
    # Tree attention writes prefix data in NHD flash layout, so the
    # same bytes produce completely different values when read in
    # paged format. Supporting ROCM_ATTN would require writing
    # prefix data via PagedAttention.write_to_paged_cache into a
    # separate paged-format KV cache.

    # TODO: Enable ROCM_AITER_FA. Its metadata builder reads head
    # counts from the model config at construction time and
    # allocates extend_workspace with those dimensions. The test
    # uses independent head count parameters (num_heads=2/4,
    # num_kv_heads=2) that don't match the model config
    # (Llama-3-8B: 32 q heads, 8 kv heads), causing a head count
    # mismatch in flash_attn_varlen_func during extend_forward.
    # Fixing this requires either matching test head counts to the
    # model config or decoupling the builder from model config
    # head geometry. The direct cache update path
    # (_NEEDS_DIRECT_CACHE_UPDATE) is already in place for when
    # this is resolved.

    return chosen

174
175

class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer module.

    Provides unit (1.0) float32 scale tensors ``_q_scale``, ``_k_scale`` and
    ``_v_scale`` on ``DEVICE_TYPE``, plus a fixed ``layer_name``. The scales
    are class attributes created once at import time, so all instances share
    them. ``forward_attention`` passes ``_k_scale``/``_v_scale`` to the direct
    cache-write path. Backends may presumably read the other attributes.
    """

    # Shared unit scales: no quantization scaling is applied in this test.
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)

    layer_name = "mock_layer"

    def forward(self, x):
        """Identity: return ``x`` unchanged."""
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: AttentionBackendEnum,
    spec_token_tree: str | None = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single attention forward pass through the given backend.

    The batch is uniform. Every request has ``q_len`` query tokens and a total
    KV length of ``seqlen_k``, of which ``seqlen_k - q_len`` are reported to
    the backend as already-computed context.

    Args:
        q: Queries, unpacked as ``(batch_size, q_len, num_heads, head_size)``.
        k: New keys. ``k.shape[-2]`` is taken as the KV head count, and the
            tensor is flattened to ``(-1, num_kv_heads, head_size)``.
        v: New values, flattened like ``k``.
        kv_cache: Cache in **flash layout**
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``. It is
            automatically converted when the target backend needs a
            different layout. When no conversion is needed, the KV update
            step is handed this very tensor, so it presumably writes into the
            caller's cache. For converted layouts it receives a copy and the
            caller's tensor is not written.
        block_table: Block table, passed through as ``block_table_tensor``.
        slot_mapping: Cache slot for each new token, passed through to the
            metadata.
        seqlen_k: Total KV length of every request.
        backend: Backend whose metadata builder and impl are exercised.
        spec_token_tree: If given, an "eagle" ``SpeculativeConfig`` carrying
            this token tree is attached to the vLLM config.
        num_spec_tokens: ``num_speculative_tokens`` for that config.

    Returns:
        The result of the backend's ``forward``. The ``output`` buffer it is
        given has shape ``(batch_size * q_len, num_heads, head_size)``.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32
    )
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size,),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    # Plain int equal to query_start_loc[-1]. Indexing the tensor would yield a
    # 0-d tensor rather than an integer token count.
    num_actual_tokens = batch_size * q_len

    softmax_scale = dim_per_head ** (-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = try_get_attention_backend(backend)
    # Pass the int max_seq_len: the builtin max() over a tensor returns a
    # 0-d tensor element, not an int.
    vllm_config = create_vllm_config(
        model_name=model_name, max_model_len=max_seq_len
    )
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree,
        )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    seq_lens_cpu = seq_lens.cpu()
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        seq_lens_cpu_upper_bound=seq_lens_cpu,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Adapt KV cache layout for this backend.
    adapted_kv_cache = _adapt_kv_cache_for_backend(kv_cache, backend)

    # Run forward pass and return output.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)

    # Presumably backends for which this returns True write the cache inside
    # forward(); all others need an explicit cache update first.
    if not try_backend_includes_kv_cache_update(backend):
        if backend in _NEEDS_DIRECT_CACHE_UPDATE:
            # This backend's do_kv_cache_update requires engine-level
            # ForwardContext that isn't available in this test harness.
            # Write directly using reshape_and_cache_flash since the
            # KV cache layout is identical (flash layout, unbind on dim 0).
            key_cache, value_cache = adapted_kv_cache.unbind(0)
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                "auto",
                layer._k_scale,
                layer._v_scale,
            )
        else:
            instance.do_kv_cache_update(
                layer=layer,
                key=key,
                value=value,
                kv_cache=adapted_kv_cache,
                slot_mapping=attn_metadata.slot_mapping,
            )

    # NOTE(review): forward receives a clone of the (updated) cache,
    # presumably so the forward pass itself cannot mutate the cache the
    # caller shares across calls - confirm.
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=adapted_kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


321
322
323
324
325
326
327
328
@pytest.mark.parametrize(
    "reference_backend",
    _get_available_reference_backends(),
    ids=lambda b: b.name,
)
def test_tree_attn_correctness(
    reference_backend: AttentionBackendEnum,
) -> None:
    """Check tree attention against per-branch runs of a reference backend.

    For each tree shape, TREE_ATTN is run once over all tree tokens. Then,
    for every row of the tree's attention mask, the tokens that row attends
    to (its branch) are run through ``reference_backend`` as a plain
    sequence with no speculative tree, placed at the same cache positions.
    The outputs for those tokens must agree within ``atol``.
    """
    set_random_seed(42)

    # Same device as the module-level tensors (MockAttentionLayer scales).
    device = DEVICE_TYPE
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True

    # Loop-invariant cache geometry.
    assert max_sequence_length % block_size == 0
    max_blocks_per_batch = max_sequence_length // block_size

    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            # Assert that the number of heads is divisible
            # by the number of KV heads.
            assert num_heads % num_kv_heads == 0
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # KV cache in flash layout - the canonical format for
                    # tree attention. forward_attention() handles conversion
                    # when needed.
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
                        -1, num_alloc_blocks_per_batch
                    )

                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size
                    )

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=AttentionBackendEnum.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify each branch against the reference backend.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size
                        )

                        # Reference attention for this branch.
                        ref_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=reference_backend,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            ref_output,
                            atol=7.81e-3,
                        ), (
                            f"outputs are not close for "
                            f"reference_backend: {reference_backend.name}, "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}."
                        )
499
500


501
502
503
def _gen_slot_mapping(
    positions: torch.Tensor, block_table: torch.Tensor, block_size: int
):
504
505
506
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)