# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

import pytest
import torch

from tests.v1.attention.utils import (
    create_standard_kv_cache_spec,
    create_vllm_config,
    try_backend_includes_kv_cache_update,
    try_get_attention_backend,
)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.platforms import current_platform
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum

# Skip the whole module at collection time when flash_attn_varlen_func is
# not available in this environment.
if not is_flash_attn_varlen_func_available():
    pytest.skip(
        "This test requires flash_attn_varlen_func, but it's not available.",
        allow_module_level=True,
    )


# --------------------------------------------------------------------------- #
#  KV cache layout adaptation
# --------------------------------------------------------------------------- #
# Backends store the paged KV cache in one of two layouts:
#
#   Flash layout: (2, num_blocks, block_size, num_kv_heads, head_size)
#     - key is index 0 and value is index 1 of dim 0
#     - backends: FLASH_ATTN, TREE_ATTN, ROCM_AITER_FA, ROCM_ATTN
#
#   Block layout: (num_blocks, 2, block_size, num_kv_heads, head_size)
#     - key is index 0 and value is index 1 of dim 1
#     - backends: TRITON_ATTN
#
# The test allocates every KV cache in flash layout (the canonical format
# for tree attention). Reference backends that expect block layout get the
# cache with dims 0 and 1 transposed.
#
# Caveat: ROCM_ATTN stores flash layout, but its forward path calls
# PagedAttention.split_kv_cache, which reinterprets the raw memory as the
# paged layout (num_blocks, num_kv_heads, head_size//x, block_size, x).
# That incompatibility is at the level of memory views; a transpose cannot
# fix it. See the TODO in _get_available_reference_backends.
#
# TODO: Replace this hand-maintained mapping with a `KV_CACHE_LAYOUT` class
# attribute on each AttentionImpl, so every backend documents its own
# layout, e.g.:
#     class TritonAttentionImpl(AttentionImpl):
#         KV_CACHE_LAYOUT = "block"
# --------------------------------------------------------------------------- #

# Backends that expect block layout. _adapt_kv_cache_for_backend transposes
# the flash-layout cache for these.
_BLOCK_KV_LAYOUT_BACKENDS = frozenset({AttentionBackendEnum.TRITON_ATTN})

# Backends whose do_kv_cache_update depends on engine-level state (e.g. a
# ForwardContext) that this harness does not set up. Their KV cache is in
# flash layout, so forward_attention() skips do_kv_cache_update for them
# and writes the new keys/values itself with reshape_and_cache_flash.
_NEEDS_DIRECT_CACHE_UPDATE = frozenset({AttentionBackendEnum.ROCM_AITER_FA})

# Backends that cannot yet act as references because of known harness
# incompatibilities (see the TODOs in _get_available_reference_backends).
_INCOMPATIBLE_REFERENCE_BACKENDS = frozenset(
    {AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_ATTN}
)


def _adapt_kv_cache_for_backend(
    kv_cache: torch.Tensor,
    backend: AttentionBackendEnum,
) -> torch.Tensor:
    """Return ``kv_cache`` in the layout that ``backend`` expects.

    ``kv_cache`` is in flash layout ``(2, num_blocks, ...)``. For backends in
    ``_BLOCK_KV_LAYOUT_BACKENDS`` the result is a contiguous tensor in block
    layout ``(num_blocks, 2, ...)``. For every other backend the input tensor
    object itself is returned, without a copy.
    """
    if backend not in _BLOCK_KV_LAYOUT_BACKENDS:
        return kv_cache
    # Swap the K/V axis with the block axis, then pack it densely. This is in
    # general a copy, so later writes to it do not reach ``kv_cache``.
    block_layout = kv_cache.transpose(0, 1)
    return block_layout.contiguous()


def _get_platform_default_backend() -> AttentionBackendEnum:
    """Return the backend the current platform would auto-select.

    The selector is queried with the geometry this test uses: block size 32,
    head size 128, bf16, "auto" KV-cache dtype, no MLA and no sparse
    attention. The returned class path is mapped back to its
    ``AttentionBackendEnum`` member.

    Raises:
        RuntimeError: if no enum member resolves to the returned path.
    """
    from vllm.v1.attention.selector import AttentionSelectorConfig

    selected_path = current_platform.get_attn_backend_cls(
        selected_backend=None,
        attn_selector_config=AttentionSelectorConfig(
            block_size=32,
            kv_cache_dtype="auto",
            use_mla=False,
            use_sparse=False,
            head_size=128,
            dtype=torch.bfloat16,
        ),
    )
    for candidate in AttentionBackendEnum:
        try:
            candidate_path = candidate.get_path()
        except ValueError:
            # get_path() raised for this member; it cannot be the match.
            continue
        if candidate_path == selected_path:
            return candidate
    raise RuntimeError(
        f"Platform returned backend path '{selected_path}' "
        f"that doesn't match any AttentionBackendEnum member."
    )


def _get_available_reference_backends() -> list[AttentionBackendEnum]:
    """Collect the reference backends that tree attention is compared against.

    On ROCm the list is:
      - the platform's auto-selected default backend, unless it appears in
        ``_INCOMPATIBLE_REFERENCE_BACKENDS``;
      - followed by ``TRITON_ATTN``, without duplicates.

    ``ROCM_ATTN`` and ``ROCM_AITER_FA`` are not used yet (see the TODOs
    below). On every other platform (e.g. CUDA) the list is just
    ``FLASH_ATTN``.
    """
    if not current_platform.is_rocm():
        # CUDA (and any other non-ROCm platform): flash attention.
        return [AttentionBackendEnum.FLASH_ATTN]

    backends: list[AttentionBackendEnum] = []

    # 1. Whatever the platform would auto-select at runtime.
    default_backend = _get_platform_default_backend()
    if default_backend not in _INCOMPATIBLE_REFERENCE_BACKENDS:
        backends.append(default_backend)

    # 2. TRITON_ATTN, which this test treats as always available on ROCm.
    #    The membership check avoids a duplicate when it is also the default.
    if AttentionBackendEnum.TRITON_ATTN not in backends:
        backends.append(AttentionBackendEnum.TRITON_ATTN)

    # TODO: Enable ROCM_ATTN. Its forward path uses
    # PagedAttention.split_kv_cache which reinterprets the raw
    # cache memory as paged layout:
    #   key:   (num_blocks, num_kv_heads, head_size//x, block_size, x)
    #   value: (num_blocks, num_kv_heads, head_size, block_size)
    # Tree attention writes prefix data in NHD flash layout, so the
    # same bytes produce completely different values when read in
    # paged format. Supporting ROCM_ATTN would require writing
    # prefix data via PagedAttention.write_to_paged_cache into a
    # separate paged-format KV cache.

    # TODO: Enable ROCM_AITER_FA. Its metadata builder reads head
    # counts from the model config at construction time and
    # allocates extend_workspace with those dimensions. The test
    # uses independent head count parameters (num_heads=2/4,
    # num_kv_heads=2) that don't match the model config
    # (Llama-3-8B: 32 q heads, 8 kv heads), causing a head count
    # mismatch in flash_attn_varlen_func during extend_forward.
    # Fixing this requires either matching test head counts to the
    # model config or decoupling the builder from model config
    # head geometry. The direct cache update path
    # (_NEEDS_DIRECT_CACHE_UPDATE) is already in place for when
    # this is resolved.

    return backends


class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for the ``layer`` argument of attention impls.

    It provides unit (1.0) q/k/v scales and a fixed ``layer_name``.
    ``forward_attention`` passes an instance as ``layer=`` to
    ``do_kv_cache_update`` and ``forward``. It also hands ``_k_scale`` and
    ``_v_scale`` straight to ``reshape_and_cache_flash``.
    """

    # Class-level tensors shared by all instances. They are created on CUDA
    # when the class is defined, so importing this module needs a CUDA device.
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    layer_name = "mock_layer"

    def forward(self, x):
        """Identity pass-through."""
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: AttentionBackendEnum,
    spec_token_tree: str | None = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single attention forward pass through the given backend.

    Every request in the batch has the same query length ``q_len`` (taken
    from ``q``) and the same total KV length ``seqlen_k``. The first
    ``seqlen_k - q_len`` positions of each request are treated as
    already-computed context.

    Args:
        q: Queries, unpacked as ``(batch_size, q_len, num_heads, head_size)``.
        k: New keys; the last two dims are ``(num_kv_heads, head_size)``.
        v: New values, viewed the same way as ``k``.
        kv_cache: KV cache in **flash layout**
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``.
            It is converted automatically when the target backend needs a
            different layout. For flash-layout backends the same tensor is
            handed to the KV-cache update, so it may be modified in place.
        block_table: Block ids per request.
        slot_mapping: Cache slot of every new token (flattened).
        seqlen_k: Total KV length (context + new tokens) per request.
        backend: Attention backend to run.
        spec_token_tree: Optional speculative token tree. When given, an
            EAGLE ``SpeculativeConfig`` is attached to the vLLM config.
        num_spec_tokens: ``num_speculative_tokens`` for that config.

    Returns:
        Whatever the backend's ``forward`` returns (``output`` is allocated
        as ``(batch_size * q_len, num_heads, head_size)``).
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths (uniform across requests).
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32
    )
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size,),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    # Use Python ints here. ``query_start_loc[-1]`` would be a 0-d device
    # tensor, and builtin ``max()`` over a CUDA tensor iterates it element by
    # element and also yields a tensor.
    num_actual_tokens = batch_size * q_len

    softmax_scale = dim_per_head ** (-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = try_get_attention_backend(backend)
    vllm_config = create_vllm_config(
        model_name=model_name, max_model_len=max_seq_len
    )
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree,
        )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens.cpu(),
        _num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Adapt KV cache layout for this backend.
    adapted_kv_cache = _adapt_kv_cache_for_backend(kv_cache, backend)

    # Flatten tokens: (batch_size * q_len, heads, head_size).
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)

    # When the backend's forward() does not itself include the KV-cache
    # update (per try_backend_includes_kv_cache_update), write the new K/V
    # into the cache first.
    if not try_backend_includes_kv_cache_update(backend):
        if backend in _NEEDS_DIRECT_CACHE_UPDATE:
            # This backend's do_kv_cache_update requires engine-level
            # ForwardContext that isn't available in this test harness.
            # Write directly using reshape_and_cache_flash since the
            # KV cache layout is identical (flash layout, unbind on dim 0).
            key_cache, value_cache = adapted_kv_cache.unbind(0)
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                "auto",
                layer._k_scale,
                layer._v_scale,
            )
        else:
            instance.do_kv_cache_update(
                layer=layer,
                key=key,
                value=value,
                kv_cache=adapted_kv_cache,
                slot_mapping=attn_metadata.slot_mapping,
            )

    # forward() receives a clone of the adapted cache, so any writes it
    # performs land in the clone rather than in ``kv_cache``.
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=adapted_kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


@pytest.mark.parametrize(
    "reference_backend",
    _get_available_reference_backends(),
    ids=lambda b: b.name,
)
def test_tree_attn_correctness(
    reference_backend: AttentionBackendEnum,
) -> None:
    """Check TREE_ATTN output against ``reference_backend``, branch by branch.

    Row ``i`` of each tree attention mask marks the tree tokens that token
    ``i`` may attend to (its root-to-``i`` branch). For every row, the
    tree-attention outputs at that branch's token indices are compared with
    a plain forward pass of ``reference_backend`` over just those tokens.
    That pass places the tokens at the same sequence position, on top of the
    same KV-cache prefix.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True

    # The cache geometry does not depend on any of the loop parameters.
    assert max_sequence_length % block_size == 0
    max_blocks_per_batch = max_sequence_length // block_size

    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            # Assert that the number of heads is divisible
            # by the number of KV heads.
            assert num_heads % num_kv_heads == 0
            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # KV cache in flash layout - the canonical format for
                    # tree attention. forward_attention() handles conversion
                    # when needed.
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )

                    # Give every request enough distinct blocks to hold
                    # seqlen_k tokens; unused block-table entries stay 0.
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
                        -1, num_alloc_blocks_per_batch
                    )

                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size
                    )

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=AttentionBackendEnum.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify each branch against the reference backend.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size
                        )

                        # Reference attention for this branch.
                        ref_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=reference_backend,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            ref_output,
                            atol=7.81e-3,
                        ), (
                            f"outputs are not close for "
                            f"reference_backend: {reference_backend.name}, "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}."
                        )


def _gen_slot_mapping(
    positions: torch.Tensor, block_table: torch.Tensor, block_size: int
):
500
501
502
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)