# paged_attention.py
import torch
from text_generation_server.utils.import_utils import SYSTEM

# Sequence-length chunk size used by the PagedAttention V2 split/reduce kernels
# (must remain a multiple of the KV-cache block size, see the assert in attention()).
_PARTITION_SIZE = 512

if SYSTEM == "xpu":
    import intel_extension_for_pytorch as ipex


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
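    """Write the key/value vectors of the current tokens into the paged KV cache
    at the flat cache positions given by `slots`, dispatching to the vLLM kernel
    on CUDA/ROCm or to the IPEX kernel on XPU."""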
    if SYSTEM == "cuda":
        from vllm._C import cache_ops

        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
    elif SYSTEM == "rocm":
        from vllm import cache_ops

        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
    elif SYSTEM == "xpu":
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots
        )
    else:
        raise ValueError("vllm is not supported on your system")


def attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
):
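    """Single-query (decode) paged attention: attend each query to its cached
    keys/values addressed through `block_tables` and write the result into `out`."""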
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    # Ceiling division: number of _PARTITION_SIZE-token partitions needed to
    # cover the longest sequence, e.g. max_s = 1300 -> 3 partitions of 512 tokens.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    if SYSTEM == "xpu":
        query = query.contiguous()
        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            block_size,
            max_s,
            None,
        )

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
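    # e.g. a batch of 32 sequences with 32 heads at max_s = 4096 (8 partitions)
    # gives 1024 parallel work units and stays on V1; a single 16k-token sequence
    # exceeds the 8192 cap and falls through to V2.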
    if use_v1:
        if SYSTEM == "cuda":
            from vllm._C import ops

            ops.paged_attention_v1(
                out,
                query,
                key_cache,
                value_cache,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        elif SYSTEM == "rocm":
            from vllm import attention_ops

            attention_ops.paged_attention_v1(
                out,
                query,
                key_cache,
                value_cache,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
            )
        else:
            raise ValueError("vllm is not supported on your system")

    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
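        # Per-(sequence, head, partition) scratch buffers: the V2 kernel computes
        # a partial output, exp-sum and max-logit for each partition, then reduces
        # them across partitions into `out`.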
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=out.dtype,
            device=out.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)

        if SYSTEM == "cuda":
            from vllm._C import ops

            ops.paged_attention_v2(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        elif SYSTEM == "rocm":
            from vllm import attention_ops

            attention_ops.paged_attention_v2(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
            )
        else:
            raise ValueError("vllm is not supported on your system")
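

# Illustrative decode-step call sequence (not part of this module; the tensor
# names and shapes below are assumptions for a hypothetical model, the real
# callers are the paged/flash model implementations in this repository):
#
#     reshape_and_cache(key, value, key_cache, value_cache, slots)
#     attention(
#         out,                # [num_seqs, num_heads, head_size], written in place
#         query,              # [num_seqs, num_heads, head_size]
#         key_cache,
#         value_cache,        # [num_blocks, num_kv_heads, head_size, block_size]
#         kv_head_mapping,    # maps each query head to its KV head
#         softmax_scale,      # typically 1 / sqrt(head_size)
#         block_tables,       # [num_seqs, max_blocks_per_seq] cache block indices
#         input_lengths,      # current length of each sequence
#         max_s,              # longest sequence length in the batch
#     )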