# test_tree_attention.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

import pytest
import torch

from tests.v1.attention.utils import (
    create_standard_kv_cache_spec,
    create_vllm_config,
    try_backend_includes_kv_cache_update,
    try_get_attention_backend,
)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum

# Skip the whole module (not just a single test) when the varlen flash
# attention function is unavailable.
if not is_flash_attn_varlen_func_available():
    pytest.skip(
        "This test requires flash_attn_varlen_func, but it's not available.",
        allow_module_level=True,
    )

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

class MockAttentionLayer(torch.nn.Module):
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: AttentionBackendEnum,
    spec_token_tree: str | None = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run one attention forward pass through the given backend.

    Builds the common attention metadata for a batch in which every request
    has the same query length and the same KV length, instantiates the
    backend's metadata builder and implementation, and runs ``forward``.

    Args:
        q: Queries; unpacked as (batch_size, q_len, num_heads, dim_per_head).
        k: Keys; the second-to-last dimension is taken as the number of KV
            heads.
        v: Values; flattened with the same KV-head layout as ``k``.
        kv_cache: KV cache tensor handed to the backend.
        block_table: Block table stored in the common metadata.
        slot_mapping: Slot mapping stored in the common metadata.
        seqlen_k: KV sequence length used for every request in the batch.
        backend: Attention backend to exercise.
        spec_token_tree: If not ``None``, a speculative config with this
            token tree is attached to the vLLM config.
        num_spec_tokens: Number of speculative tokens for that config.

    Returns:
        Whatever the backend implementation's ``forward`` returns.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32
    )
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size,),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    # Same value as query_start_loc[-1], but as a plain int rather than a
    # 0-dim device tensor.
    num_actual_tokens = batch_size * q_len

    softmax_scale = q.shape[-1] ** (-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = try_get_attention_backend(backend)
    # Pass the already-computed int instead of max() over a device tensor.
    vllm_config = create_vllm_config(model_name=model_name, max_model_len=max_seq_len)
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree,
        )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens.cpu(),
        _num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Flatten the batch and token dimensions for the backend.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)

    # Backends whose forward does not include the KV cache update need the
    # new keys/values written into the cache explicitly beforehand.
    if not try_backend_includes_kv_cache_update(backend):
        instance.do_kv_cache_update(
            layer=layer,
            key=key,
            value=value,
            kv_cache=kv_cache,
            slot_mapping=attn_metadata.slot_mapping,
        )

    # Run forward pass and return output. forward receives a copy of the
    # cache -- presumably so that backends which write K/V inside forward do
    # not modify the caller's tensor (TODO confirm).
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


def test_tree_attn_correctness() -> None:
    """Check tree attention against per-branch flash attention.

    For each combination of batch size, head count, sequence position and
    token tree, the tree attention output is computed once for the whole
    tree. Then, for every row of the tree attention mask, the tokens selected
    by that row are run as a chain through the flash attention backend, and
    the result is compared with the matching rows of the tree output.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    device = "cuda"
    # Maps a speculative token tree (as passed to the speculative config) to
    # its attention mask; row i has ones at the tokens that query i attends
    # to within the tree.
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True

    # Loop-invariant paged-KV geometry, checked once up front.
    assert max_sequence_length % block_size == 0
    max_blocks_per_batch = max_sequence_length // block_size

    for batch_size in [1, 16, 32]:
        for num_heads in [2, 4]:
            # Assert that the number of heads is divisible
            # by the number of KV heads.
            assert num_heads % num_kv_heads == 0

            for sequence_position in [16, 1024, 2048]:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # Set up the block table and KV cache for paged KV.
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
                        -1, num_alloc_blocks_per_batch
                    )

                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size
                    )

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=AttentionBackendEnum.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify that the chain attention output for each
                    # branch of the tree (computed using FA3) matches
                    # the tree attention output.
                    for q_index in range(tree_size_q):
                        # Get the q, k, and v for the branch: the tokens
                        # with a nonzero entry in this row of the mask.
                        branch_mask = tree_attn_mask[q_index, :]
                        branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # Setup slot mapping for the branch. The branch is
                        # laid out as q_len consecutive positions starting
                        # at sequence_position.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size
                        )

                        # Compute flash attention for the branch.
                        flash_attn_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=AttentionBackendEnum.FLASH_ATTN,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            flash_attn_output,
                            atol=7.81e-3,
                        ), (
                            f"outputs are not close for "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}."
                        )
314
315


316
317
318
def _gen_slot_mapping(
    positions: torch.Tensor, block_table: torch.Tensor, block_size: int
):
319
320
321
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)