# paged_attention.py
import torch

# vllm imports
from vllm._C import cache_ops, ops

_PARTITION_SIZE = 512


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
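    # Scatter this step's key/value vectors into the paged KV cache, one flat
    # cache slot index per token as given by `slots`. The trailing "auto" and
    # 1.0 arguments correspond to the kernel's KV-cache dtype selector and
    # KV scale at the pinned vLLM commit.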
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
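    # query => [num_seqs, num_heads, head_size]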
    num_seqs, num_heads, head_size = query.shape
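    # Ceiling division: how many _PARTITION_SIZE-token partitions the longest
    # sequence would be split into under PagedAttention V2.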
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
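    # For example, max_s = 4096 gives 8 partitions, but a decode batch of
    # 16 sequences x 64 heads = 1024 > 512 still selects V1; V2 only runs for
    # long contexts with few sequences/heads, or whenever max_s > 8192.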
    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
    if use_v1:
        ops.paged_attention_v1(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
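            # Remaining positional args: alibi_slopes (unused here), the
            # KV-cache dtype, and the KV scale.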
            None,
            "auto",
            1.0,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
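        # V2 splits each sequence into _PARTITION_SIZE-token partitions and
        # writes per-partition outputs and softmax statistics (exp_sums,
        # max_logits) into the scratch buffers below before a final reduction.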
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
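            # Same trailing args as the V1 call above: alibi_slopes, KV-cache
            # dtype, and KV scale.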
            None,
            "auto",
            1.0,
        )
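

# Shape reference for callers (inferred from the calls above and from the vLLM
# paged-attention conventions at the pinned commit; treat the exact layouts as
# assumptions and verify against your vLLM build):
#   query           => [num_seqs, num_heads, head_size]  (one decode token per sequence)
#   value_cache     => [num_blocks, num_heads, head_size, block_size]
#   block_tables    => [num_seqs, max_blocks_per_seq] int tensor of cache block ids
#   input_lengths   => [num_seqs] int tensor with the number of cached tokens per sequence
#   kv_head_mapping => [num_heads] int tensor mapping each query head to its KV head
#   slots           => one flat cache slot index per token for reshape_and_cache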