test_tree_attention.py 19 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

6
import pytest
7
8
import torch

9
10
11
from tests.v1.attention.utils import (
    create_standard_kv_cache_spec,
    create_vllm_config,
12
    try_backend_includes_kv_cache_update,
13
    try_get_attention_backend,
14
)
15
from vllm.config import ParallelConfig, SpeculativeConfig
16
from vllm.platforms import current_platform
17
from vllm.v1.attention.backend import CommonAttentionMetadata
18
19
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum
20

21
22
# Device type string reported by the current platform. It is used below to
# place the mock layer's scale tensors.
DEVICE_TYPE = current_platform.device_type

# Every test in this module needs flash_attn_varlen_func, so skip the whole
# module at collection time when it is unavailable.
if not is_flash_attn_varlen_func_available():
    pytest.skip(
        "This test requires flash_attn_varlen_func, but it's not available.",
        allow_module_level=True,
    )

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# --------------------------------------------------------------------------- #
#  KV cache layout adaptation
# --------------------------------------------------------------------------- #
# Two KV cache layouts exist across backends:
#
#   Flash layout: (2, num_blocks, block_size, num_kv_heads, head_size)
#     - dim 0 separates key (index 0) and value (index 1)
#     - Used by: FLASH_ATTN, TREE_ATTN, ROCM_AITER_FA, ROCM_ATTN
#
#   Block layout: (num_blocks, 2, block_size, num_kv_heads, head_size)
#     - dim 1 separates key (index 0) and value (index 1)
#     - Used by: TRITON_ATTN
#
# The test creates KV caches in flash layout (the canonical format used by
# tree attention). When a reference backend needs block layout we transpose
# dims 0 and 1.
#
# Note: ROCM_ATTN uses flash layout for storage but its forward path calls
# PagedAttention.split_kv_cache which reinterprets the raw memory as paged
# layout (num_blocks, num_kv_heads, head_size//x, block_size, x). This is
# a view-level incompatibility, not a transpose - see the TODO in
# _get_available_reference_backends for details.
#
# TODO: Replace this mapping with a `KV_CACHE_LAYOUT` class attribute on each
# AttentionImpl so the layout is self-documented by the backend itself, e.g.:
#     class TritonAttentionImpl(AttentionImpl):
#         KV_CACHE_LAYOUT = "block"
# --------------------------------------------------------------------------- #

# Backends that expect the KV cache in block layout
# (num_blocks, 2, block_size, num_kv_heads, head_size). The cache is
# transposed for these before it is handed to the backend.
_BLOCK_KV_LAYOUT_BACKENDS = frozenset((AttentionBackendEnum.TRITON_ATTN,))

# Backends whose do_kv_cache_update requires engine-level state (e.g.
# ForwardContext) that is not available in this test harness, but whose
# KV cache is flash layout and can be written with reshape_and_cache_flash.
# For these, forward_attention() bypasses do_kv_cache_update and writes
# directly to the cache.
_NEEDS_DIRECT_CACHE_UPDATE = frozenset((AttentionBackendEnum.ROCM_AITER_FA,))

# Backends with known test-harness incompatibilities. They are never used as
# a reference; the TODOs inside _get_available_reference_backends explain
# what blocks each one.
_INCOMPATIBLE_REFERENCE_BACKENDS = frozenset(
    (
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.ROCM_ATTN,
    )
)


def _adapt_kv_cache_for_backend(
    kv_cache: torch.Tensor,
    backend: AttentionBackendEnum,
) -> torch.Tensor:
    """Return ``kv_cache`` in the layout that ``backend`` expects.

    The input is assumed to be in flash layout ``(2, num_blocks, ...)``.
    Backends listed in ``_BLOCK_KV_LAYOUT_BACKENDS`` receive a contiguous
    copy with dims 0 and 1 swapped, i.e. block layout
    ``(num_blocks, 2, ...)``. Every other backend gets the very same tensor
    object back, not a copy.
    """
    wants_block_layout = backend in _BLOCK_KV_LAYOUT_BACKENDS
    if not wants_block_layout:
        return kv_cache

    # transpose() only produces a view; contiguous() materializes it as a
    # separate tensor in the new memory order.
    block_layout_cache = kv_cache.transpose(0, 1)
    return block_layout_cache.contiguous()


def _get_platform_default_backend() -> AttentionBackendEnum:
    """Return the backend the platform would pick when none is requested.

    The selector is queried with a fixed configuration (block size 32, head
    size 128, bfloat16, no MLA, no sparse attention), which mirrors the
    values used by the test in this module. The class path it returns is
    then mapped back to the matching ``AttentionBackendEnum`` member.

    Raises:
        RuntimeError: If no enum member resolves to the returned path.
    """
    from vllm.v1.attention.selector import AttentionSelectorConfig

    selector_config = AttentionSelectorConfig(
        block_size=32,
        kv_cache_dtype="auto",
        use_mla=False,
        use_sparse=False,
        head_size=128,
        dtype=torch.bfloat16,
    )
    backend_path = current_platform.get_attn_backend_cls(
        selected_backend=None,
        attn_selector_config=selector_config,
    )

    for candidate in AttentionBackendEnum:
        try:
            candidate_path = candidate.get_path()
        except ValueError:
            # Presumably a member with no resolvable class path; it cannot
            # be the platform default, so move on to the next one.
            continue
        if candidate_path == backend_path:
            return candidate

    raise RuntimeError(
        f"Platform returned backend path '{backend_path}' "
        f"that doesn't match any AttentionBackendEnum member."
    )


def _get_available_reference_backends() -> list[AttentionBackendEnum]:
    """List the reference backends tree attention is validated against.

    On non-ROCm platforms this is just ``FLASH_ATTN``. On ROCm it is the
    platform's auto-selected backend (unless that backend is listed in
    ``_INCOMPATIBLE_REFERENCE_BACKENDS``) followed by ``TRITON_ATTN``,
    without duplicates.
    """
    if not current_platform.is_rocm():
        # CUDA: flash attention.
        return [AttentionBackendEnum.FLASH_ATTN]

    reference_backends: list[AttentionBackendEnum] = []

    # 1. Whatever the platform would auto-select at runtime, provided the
    #    test harness can actually drive it.
    platform_default = _get_platform_default_backend()
    if platform_default not in _INCOMPATIBLE_REFERENCE_BACKENDS:
        reference_backends.append(platform_default)

    # 2. TRITON_ATTN - always available on ROCm. Skip it when it already is
    #    the platform default so it is not parametrized twice.
    triton_backend = AttentionBackendEnum.TRITON_ATTN
    if triton_backend not in reference_backends:
        reference_backends.append(triton_backend)

    # TODO: Enable ROCM_ATTN. Its forward path uses
    # PagedAttention.split_kv_cache which reinterprets the raw
    # cache memory as paged layout:
    #   key:   (num_blocks, num_kv_heads, head_size//x, block_size, x)
    #   value: (num_blocks, num_kv_heads, head_size, block_size)
    # Tree attention writes prefix data in NHD flash layout, so the
    # same bytes produce completely different values when read in
    # paged format. Supporting ROCM_ATTN would require writing
    # prefix data via PagedAttention.write_to_paged_cache into a
    # separate paged-format KV cache.

    # TODO: Enable ROCM_AITER_FA. Its metadata builder reads head
    # counts from the model config at construction time and
    # allocates extend_workspace with those dimensions. The test
    # uses independent head count parameters (num_heads=2/4,
    # num_kv_heads=2) that don't match the model config
    # (Llama-3-8B: 32 q heads, 8 kv heads), causing a head count
    # mismatch in flash_attn_varlen_func during extend_forward.
    # Fixing this requires either matching test head counts to the
    # model config or decoupling the builder from model config
    # head geometry. The direct cache update path
    # (_NEEDS_DIRECT_CACHE_UPDATE) is already in place for when
    # this is resolved.

    return reference_backends

173
174

class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer, passed as ``layer`` to the
    backend implementations under test.

    It only carries the attributes the test harness hands to the backends
    and performs no computation of its own.
    """

    # Unit scales, created once at class-definition time on the platform
    # device. forward_attention() passes _k_scale/_v_scale to
    # reshape_and_cache_flash; _q_scale is presumably read by the backend
    # implementations themselves (not visible in this file).
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
    # Name presumably used by backends to identify the layer.
    layer_name = "mock_layer"

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        """Identity: return ``x`` unchanged."""
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: AttentionBackendEnum,
    spec_token_tree: str | None = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single attention forward pass through the given backend.

    ``kv_cache`` is expected in **flash layout**
    ``(2, num_blocks, block_size, num_kv_heads, head_size)``.
    It is automatically converted when the target backend needs a
    different layout.

    Args:
        q: Queries of shape ``(batch_size, q_len, num_heads, head_size)``.
        k: New keys; the second-to-last dim is ``num_kv_heads`` and the
            tensor is flattened to ``(-1, num_kv_heads, head_size)``.
        v: New values, flattened the same way as ``k``.
        kv_cache: Paged KV cache in flash layout (see above).
        block_table: Per-request block ids, forwarded to the metadata as
            ``block_table_tensor``.
        slot_mapping: Flat cache slot for every new key/value token.
        seqlen_k: KV sequence length, applied to every request in the batch.
        backend: Attention backend to build and run.
        spec_token_tree: Speculative token tree description. When given, a
            ``SpeculativeConfig`` is attached to the vLLM config.
        num_spec_tokens: Number of speculative tokens for that config.

    Returns:
        The value returned by the backend's ``forward`` call. Callers in
        this file reshape it to ``(batch_size, -1, num_heads, head_size)``.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths. Every request shares the
    # same query length (q_len) and the same KV length (seqlen_k).
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32
    )
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size,),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    # Total number of query tokens as a plain int. This equals
    # query_start_loc[-1], but indexing that would yield a 0-dim device
    # tensor instead of an int.
    num_actual_tokens = batch_size * q_len

    softmax_scale = dim_per_head ** (-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = try_get_attention_backend(backend)
    # Pass the int computed above. The builtin max() over a tensor would
    # iterate it element by element and return a 0-dim tensor.
    vllm_config = create_vllm_config(model_name=model_name, max_model_len=max_seq_len)
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree,
        )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens.cpu(),
        _num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Adapt KV cache layout for this backend.
    adapted_kv_cache = _adapt_kv_cache_for_backend(kv_cache, backend)

    # Flatten the batch and token dims; the backends take
    # (num_tokens, heads, head_size) inputs.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)

    # Backends that do not write the new K/V inside forward() need an
    # explicit cache update first.
    if not try_backend_includes_kv_cache_update(backend):
        if backend in _NEEDS_DIRECT_CACHE_UPDATE:
            # This backend's do_kv_cache_update requires engine-level
            # ForwardContext that isn't available in this test harness.
            # Write directly using reshape_and_cache_flash since the
            # KV cache layout is identical (flash layout, unbind on dim 0).
            key_cache, value_cache = adapted_kv_cache.unbind(0)
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                "auto",
                layer._k_scale,
                layer._v_scale,
            )
        else:
            instance.do_kv_cache_update(
                layer=layer,
                key=key,
                value=value,
                kv_cache=adapted_kv_cache,
                slot_mapping=attn_metadata.slot_mapping,
            )

    # Run forward pass and return output. forward() receives a copy of the
    # cache, presumably so that a backend which updates the cache inside
    # forward() cannot modify the caller's tensor.
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=adapted_kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


318
319
320
321
322
323
324
325
@pytest.mark.parametrize(
    "reference_backend",
    _get_available_reference_backends(),
    ids=lambda b: b.name,
)
def test_tree_attn_correctness(
    reference_backend: AttentionBackendEnum,
) -> None:
    """Check TREE_ATTN against a reference backend, branch by branch.

    For every combination of batch size, head count, sequence position and
    token tree, tree attention is run once over the whole tree. Then, for
    each tree node, the tokens selected by that node's mask row are run as
    an ordinary sequence through ``reference_backend``. The tree output at
    those tokens must match the reference within ``atol=7.81e-3``.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    # Maps a speculative token tree description to its attention mask. The
    # nonzero entries of row i select the tokens of the branch verified for
    # query i.
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True

    # Loop-invariant cache geometry.
    assert max_sequence_length % block_size == 0
    max_blocks_per_batch = max_sequence_length // block_size

    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            # Assert that the number of heads is divisible
            # by the number of KV heads.
            assert num_heads % num_kv_heads == 0

            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # KV cache in flash layout - the canonical format for
                    # tree attention. forward_attention() handles conversion
                    # when needed.
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )

                    # Give every request just enough blocks to hold
                    # seqlen_k tokens; unused table entries stay 0.
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
                        -1, num_alloc_blocks_per_batch
                    )

                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size
                    )

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=AttentionBackendEnum.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify each branch against the reference backend.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch. The branch
                        # tokens are laid out consecutively starting at
                        # sequence_position.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size
                        )

                        # Reference attention for this branch.
                        ref_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=reference_backend,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            ref_output,
                            atol=7.81e-3,
                        ), (
                            f"outputs are not close for "
                            f"reference_backend: {reference_backend.name}, "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}."
                        )
497
498


499
500
501
def _gen_slot_mapping(
    positions: torch.Tensor, block_table: torch.Tensor, block_size: int
):
502
503
504
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)