# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import ANY, patch

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl,
                                               PallasMetadata)


def test_ragged_paged_attention():
    """Check the arguments ``forward`` hands to the ragged paged attention op.

    The op ``torch.ops.xla.ragged_paged_attention`` is replaced with a mock,
    so no attention is actually computed here. The test only verifies that
    the values the implementation was constructed with (``scale``,
    ``sliding_window``, ``logits_soft_cap``) reach the op call as
    ``sm_scale``, ``sliding_window`` and ``soft_cap``.
    """
    # We verify that the kernel inputs such as sliding_window, etc. are passed
    # in from the model correctly.
    # The correctness of the paged attention kernel is tested in the kernel
    # library.
    num_heads = 4
    head_size = 128
    scale = 1.0
    num_kv_heads = 4
    sliding_window = 128
    logits_soft_cap = 50.0
    attn_impl = PallasAttentionBackendImpl(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=sliding_window,
        kv_cache_dtype="auto",
        logits_soft_cap=logits_soft_cap,
        attn_type=AttentionType.DECODER,
    )

    # Minimal stand-in for the attention layer object passed to ``forward``;
    # it only carries the two float scale attributes set below.
    class FakeAttentionLayer:
        _k_scale_float: float
        _v_scale_float: float

    layer = FakeAttentionLayer()
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0

    # All tensors are zero/one filled: their values are irrelevant because the
    # op is mocked and the tensor arguments are matched with ``ANY``.
    num_tokens = 16
    num_blocks = 1024
    block_size = 16
    query = torch.zeros(num_tokens, num_heads * head_size)
    key = torch.zeros(num_tokens, num_kv_heads * head_size)
    value = torch.zeros(num_tokens, num_kv_heads * head_size)
    kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
    slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
    max_num_reqs = 8
    max_num_blocks_per_req = 8
    block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
                               dtype=torch.int32)
    context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
    # One query token per request; the leading 0 makes the cumulative sum
    # start at offset 0, giving max_num_reqs + 1 entries.
    query_lens = [1] * max_num_reqs
    query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                dtype=torch.int32),
                                   dim=0,
                                   dtype=torch.int32)
    num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
    attn_metadata = PallasMetadata(
        slot_mapping=slot_mapping,
        block_tables=block_tables,
        context_lens=context_lens,
        query_start_loc=query_start_loc,
        num_seqs=num_seqs,
    )

    with patch("torch.ops.xla.ragged_paged_attention"
               ) as mock_ragged_paged_attention:
        attn_impl.forward(
            layer=layer,
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Positional tensor arguments are not inspected; the keyword
        # arguments must match exactly.
        mock_ragged_paged_attention.assert_called_once_with(
            ANY,  # query
            ANY,  # kv_cache
            ANY,  # context_lens
            ANY,  # block_tables
            ANY,  # query_start_loc
            ANY,  # num_seqs
            num_kv_pages_per_block=None,
            num_queries_per_block=None,
            vmem_limit_bytes=None,
            use_kernel=True,
            sm_scale=scale,
            sliding_window=sliding_window,
            soft_cap=logits_soft_cap,
        )