flashinfer.py 7.47 KB
Newer Older
jixx's avatar
jixx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager

import flashinfer
import torch

# Context variables holding the flashinfer wrapper that is currently active.
# Each one is bound by the matching `use_*` context manager in this module for
# the duration of its `with` block and reset afterwards. None of them has a
# default, so `.get()` raises LookupError when nothing is bound.

# Ragged-KV prefill wrapper; bound by `use_prefill_state`.
prefill_state: ContextVar[
    flashinfer.BatchPrefillWithRaggedKVCacheWrapper
] = ContextVar("prefill_state")

# Paged-KV prefill wrapper; bound by `use_prefill_with_paged_kv_state`.
prefill_with_paged_kv_state: ContextVar[
    flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")

# Paged-KV decode wrapper; bound by `use_decode_state`.
decode_state: ContextVar[
    flashinfer.BatchDecodeWithPagedKVCacheWrapper
] = ContextVar("decode_state")

# Scratch buffer shared by every flashinfer wrapper built in this module.
# It stays None until the first `get_workspace` call allocates it.
workspace: Optional[torch.Tensor] = None


def get_workspace(device):
    """Return the shared flashinfer workspace, allocating it on first use.

    The buffer is a 1-D ``torch.uint8`` tensor of 128 MiB, created on
    ``device`` the first time this is called. Every later call returns that
    same tensor; the ``device`` argument is then ignored.
    """
    global workspace
    if workspace is not None:
        return workspace
    num_bytes = 128 * 1024 * 1024  # 128 MiB
    workspace = torch.empty(num_bytes, dtype=torch.uint8, device=device)
    return workspace


def create_prefill_with_paged_kv_state(
    *,
    device: torch.device,
):
    """Build a prefill wrapper that reads keys/values from the paged KV cache.

    The wrapper uses the shared workspace from `get_workspace`, the ``"NHD"``
    KV layout, and has CUDA-graph support disabled.
    """
    shared_buffer = get_workspace(device)
    wrapper_cls = flashinfer.BatchPrefillWithPagedKVCacheWrapper
    return wrapper_cls(
        shared_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
    )


@contextmanager
def use_prefill_with_paged_kv_state(
    *,
    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
    block_tables: torch.Tensor,
    cu_seqlens: torch.Tensor,
    input_lengths: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.

    On entry, `state.begin_forward` is called with paged-KV metadata derived
    from `input_lengths`, and `state` is bound to the
    `prefill_with_paged_kv_state` context variable. On exit,
    `state.end_forward()` is called and the context variable is always
    restored, even if `end_forward` raises.

    Args:
        state: paged-KV prefill wrapper to activate (e.g. one returned by
            `create_prefill_with_paged_kv_state`).
        block_tables: passed as `paged_kv_indices`; presumably the flattened
            page indices of all sequences (TODO confirm layout with callers).
        cu_seqlens: passed as `qo_indptr` (query offsets per sequence).
        input_lengths: per-sequence KV lengths, one entry per sequence; used
            to build the page indptr and last-page lengths.
        num_heads: number of query heads.
        num_kv_heads: number of key/value heads.
        head_size: per-head dimension.
        page_size: number of tokens per KV-cache page.
        dtype: query data type.
        window_left: forwarded unchanged to `begin_forward`.
    """
    device = input_lengths.device
    batch_size = input_lengths.shape[0]

    # indptr[0] = 0 and indptr[i + 1] = sum_{j <= i} ceil(len_j / page_size):
    # round each length up to whole pages, then take the cumulative sum to get
    # each sequence's offset into the block table.
    indptr = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Number of tokens in each sequence's last page: ((len - 1) % page_size) + 1.
    # With page_size == 1 that expression is always 1, so skip the arithmetic.
    if page_size == 1:
        last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device)
    else:
        last_page_len = torch.empty(batch_size, dtype=torch.int32, device=device)
        torch.sub(input_lengths, 1, out=last_page_len)
        last_page_len.remainder_(page_size)
        last_page_len += 1

    # `ContextVar.set` always returns a Token, so no None check is needed.
    token = prefill_with_paged_kv_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            paged_kv_indptr=indptr,
            paged_kv_indices=block_tables,
            paged_kv_last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=dtype,
            page_size=page_size,
            window_left=window_left,
        )
        yield
    finally:
        # Nested finally: if `end_forward` raises, the context variable is
        # still reset instead of leaking this wrapper to later code.
        try:
            state.end_forward()
        finally:
            prefill_with_paged_kv_state.reset(token)


def create_prefill_state(
    *,
    device: torch.device,
):
    """Build a ragged-KV prefill wrapper backed by the shared workspace.

    The wrapper uses the ``"NHD"`` KV layout and has CUDA-graph support
    disabled.
    """
    shared_buffer = get_workspace(device)
    wrapper_cls = flashinfer.BatchPrefillWithRaggedKVCacheWrapper
    return wrapper_cls(
        shared_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
    )


@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.

    On entry, `state.begin_forward` is called and `state` is bound to the
    `prefill_state` context variable. On exit, `state.end_forward()` is called
    and the context variable is always restored, even if `end_forward` raises.

    Args:
        state: ragged-KV prefill wrapper to activate (e.g. one returned by
            `create_prefill_state`).
        cu_seqlens: used as both `qo_indptr` and `kv_indptr`, i.e. queries and
            keys/values share the same per-sequence offsets.
        num_heads: number of query heads.
        num_kv_heads: number of key/value heads.
        head_size: per-head dimension.
        dtype: query data type.
        window_left: forwarded unchanged to `begin_forward`.
    """
    # `ContextVar.set` always returns a Token, so no None check is needed.
    token = prefill_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            kv_indptr=cu_seqlens,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=dtype,
            window_left=window_left,
        )
        yield
    finally:
        # Nested finally: if `end_forward` raises, the context variable is
        # still reset instead of leaking this wrapper to later code.
        try:
            state.end_forward()
        finally:
            prefill_state.reset(token)


def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """Build a paged-KV decode wrapper (CUDA graphs disabled).

    Tensor cores are enabled unless ``num_heads // num_kv_heads`` is one of
    1, 2, 4 or 8.
    """
    shared_buffer = get_workspace(device)
    heads_per_kv_head = num_heads // num_kv_heads
    # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
    tensor_cores = heads_per_kv_head not in (1, 2, 4, 8)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        shared_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
        use_tensor_cores=tensor_cores,
    )


def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Build a paged-KV decode wrapper for use with CUDA Graphs.

    `block_tables`, `block_tables_ptr` and `last_page_len` are used in CUDA
    Graphs, so they are handed to the wrapper as its paged-KV indices, indptr
    and last-page-length buffers and become part of the state. Tensor cores
    are enabled unless `num_heads // num_kv_heads` is one of 1, 2, 4 or 8.
    """
    shared_buffer = get_workspace(device)
    heads_per_kv_head = num_heads // num_kv_heads
    # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
    tensor_cores = heads_per_kv_head not in (1, 2, 4, 8)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        shared_buffer,
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=tensor_cores,
    )


@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    kv_cache_dtype: torch.dtype,
    dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer decoding state to the given
    `state` and parameters. This state will be used by all calls to the
    `paged_attention` function while the context manager is active.

    On entry, `state.begin_forward` is called with paged-KV metadata derived
    from `input_lengths`, and `state` is bound to the `decode_state` context
    variable. On exit, `state.end_forward()` is called and the context
    variable is always restored, even if `end_forward` raises.

    Args:
        state: decode wrapper to activate (e.g. one returned by
            `create_decode_state` or `create_decode_state_cuda_graphs`).
        input_lengths: per-sequence KV lengths, one entry per sequence; used
            to build the page indptr and last-page lengths.
        block_tables: passed as `indices`; presumably the flattened page
            indices of all sequences (TODO confirm layout with callers).
        num_heads: number of query heads.
        num_kv_heads: number of key/value heads.
        head_size: per-head dimension.
        page_size: number of tokens per KV-cache page.
        kv_cache_dtype: passed as `data_type` (KV-cache element type).
        dtype: query data type.
        window_left: forwarded unchanged to `begin_forward`.
    """
    device = input_lengths.device
    batch_size = input_lengths.shape[0]

    # indptr[0] = 0 and indptr[i + 1] = sum_{j <= i} ceil(len_j / page_size):
    # round each length up to whole pages, then take the cumulative sum to get
    # each sequence's offset into the block table.
    indptr = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Number of tokens in each sequence's last page: ((len - 1) % page_size) + 1.
    last_page_len = torch.empty(batch_size, dtype=torch.int32, device=device)
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1

    # `ContextVar.set` always returns a Token, so no None check is needed.
    token = decode_state.set(state)
    try:
        state.begin_forward(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            data_type=kv_cache_dtype,
            q_data_type=dtype,
            window_left=window_left,
        )
        yield
    finally:
        # Nested finally: if `end_forward` raises, the context variable is
        # still reset instead of leaking this wrapper to later code.
        try:
            state.end_forward()
        finally:
            decode_state.reset(token)