flashinfer.py 10.5 KB
Newer Older
1
2
3
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type

4
5
6
7
8
9
10
11
12
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None
    BatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None

13
14
15
16
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
17
                                              AttentionMetadata)
18
19
20
21


class FlashInferBackend(AttentionBackend):
    """Backend descriptor that wires the FlashInfer classes together.

    Exposes the backend name, the implementation and metadata classes, the
    KV-cache tensor layout and the supported head sizes. Swapping and copying
    cache blocks are not implemented by this backend.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "flashinfer"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class of this backend."""
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class of this backend."""
        return FlashInferMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the KV-cache tensor.

        The layout is ``(num_blocks, 2, block_size, num_kv_heads,
        head_size)``. ``FlashInferImpl.forward`` writes keys through
        ``kv_cache[:, 0]`` and values through ``kv_cache[:, 1]``.
        """
        kv_planes = 2  # one plane for keys, one for values
        return (num_blocks, kv_planes, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Not implemented for this backend.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Not implemented for this backend.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by ``FlashInferMetadata``."""
        head_sizes = [64, 128, 256]
        return head_sizes


@dataclass
class FlashInferMetadata(AttentionMetadata):
    """Attention metadata consumed by the FlashInfer backend.

    On top of the base ``AttentionMetadata`` fields it carries the FlashInfer
    prefill/decode wrappers and the paged KV-cache index tensors that are
    handed to those wrappers in ``begin_forward``.
    """
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None
    # Device the paged kv index tensors are moved to in begin_forward
    device: torch.device = torch.device("cuda")

    def __post_init__(self):
        """Validate ``head_dim`` against the backend's supported sizes.

        Raises:
            ValueError: if ``head_dim`` is set and is not one of
                ``FlashInferBackend.get_supported_head_sizes()``.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # Adjacent f-strings are concatenated into ONE message; a comma
            # between them would make ValueError carry a 2-tuple of args.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def _move_paged_kv_to_device(self) -> None:
        """Move the three paged-KV index tensors onto ``self.device``."""
        assert self.paged_kv_indices is not None
        assert self.paged_kv_indptr is not None
        assert self.paged_kv_last_page_len is not None
        self.paged_kv_indices = self.paged_kv_indices.to(self.device)
        self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
        self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
            self.device)

    def begin_forward(self):
        """Hand the paged-KV layout to the FlashInfer wrapper for this batch.

        When the batch has prefill tokens, the prefill wrapper is used;
        otherwise the decode wrapper is used. In both cases the wrapper's
        ``end_forward()`` is called before its ``begin_forward(...)``.
        """
        if self.num_prefill_tokens > 0:
            # Nothing to set up without page indices (presumably the
            # profiling run, which has no KV cache yet -- TODO confirm).
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            self._move_paged_kv_to_device()
            self.prefill_wrapper.end_forward()
            self.prefill_wrapper.begin_forward(
                self.query_start_loc, self.paged_kv_indptr,
                self.paged_kv_indices, self.paged_kv_last_page_len,
                self.num_qo_heads, self.num_kv_heads, self.head_dim,
                self.page_size)
        else:
            # With CUDA graphs the tensors are left where they are
            # (presumably already on the device -- verify against caller).
            if not self.use_cuda_graph:
                self._move_paged_kv_to_device()

            assert self.decode_wrapper is not None
            self.decode_wrapper.end_forward()
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the base class's dict view, minus the wrapper fields.

        Args:
            skip_fields: extra field names to leave out. The caller's set is
                copied, not modified.
        """
        # Copy so that the caller's set is never mutated.
        skip_fields = set() if skip_fields is None else set(skip_fields)
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` for a prefill-only batch, otherwise ``None``."""
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self

        return None

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        """Return ``self`` for a batch without prefills, otherwise ``None``."""
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None

        return self

188
189
190
191
192
193
194
195

class FlashInferImpl(AttentionImpl):
    """Attention layer implementation backed by FlashInfer.

    Prefill uses the FlashInfer prefill wrapper, or flash attention when no
    KV cache is given; decode uses the FlashInfer decode wrapper.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store the attention geometry and options.

        Raises:
            ValueError: if ``sliding_window`` is not ``None``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        # ALiBi slopes are kept as a float32 tensor when provided.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        # (-1, -1) is the value handed to flash attention as window_size.
        self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype

        # Query heads must split evenly across the KV heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Run attention for one batch and return the attention output.

        Args:
            query: 2-D tensor, unpacked as ``(num_tokens, hidden_size)``.
            key: tensor viewable as ``(-1, num_kv_heads, head_size)``.
            value: tensor viewable as ``(-1, num_kv_heads, head_size)``.
            kv_cache: paged KV cache, or ``None`` (e.g. while profiling).
            attn_metadata: metadata for this batch.
            kv_scale: must be ``1.0``.

        Returns:
            The attention output viewed as ``(num_tokens, hidden_size)``.
        """
        assert kv_scale == 1.0
        num_tokens, hidden_size = query.shape

        # Split the hidden dimension into per-head slices.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # A batch must be either all-prefill or all-decode.
        prefill_tokens = attn_metadata.num_prefill_tokens
        decode_tokens = attn_metadata.num_decode_tokens
        if prefill_tokens > 0:
            assert decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if decode_tokens > 0:
            assert prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            key_cache = kv_cache[:, 0]
            value_cache = kv_cache[:, 1]
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
            )

        # Flashinfer requires query to be contiguous
        query = query.contiguous()

        prefill_meta = attn_metadata.prefill_metadata
        if not prefill_meta:
            # Decode path.
            decode_meta = attn_metadata.decode_metadata
            assert decode_meta is not None
            decode_wrapper = decode_meta.decode_wrapper
            assert decode_wrapper is not None
            output = decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
            )
        elif kv_cache is None:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when vllm runs the profiling to
            # determine the number of blocks.
            start_loc = prefill_meta.seq_start_loc
            max_len = prefill_meta.max_prefill_seq_len
            output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=start_loc,
                cu_seqlens_k=start_loc,
                max_seqlen_q=max_len,
                max_seqlen_k=max_len,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
                alibi_slopes=self.alibi_slopes,
            )
        else:
            prefill_wrapper = prefill_meta.prefill_wrapper
            assert prefill_wrapper is not None
            # NOTE(review): unlike the decode call, no sm_scale is passed
            # here, so presumably the wrapper's default scale applies --
            # confirm it matches self.scale.
            output = prefill_wrapper.forward(query, kv_cache, causal=True)

        return output.view(num_tokens, hidden_size)