input_batch.py 10.6 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
8
9
10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any

import numba
import numba.types as types
import numpy as np
import torch

11
from vllm.triton_utils import tl, triton
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer


class InputBuffers:
    """Preallocated, reusable input buffers sized for the largest batch.

    Buffers created through ``_make_buffer`` are ``CpuGpuBuffer`` objects,
    used in this file through their host-side NumPy view (``.np``) and
    ``copy_to_gpu()``. ``positions`` and ``seq_lens`` are plain tensors
    allocated directly on ``device`` (no host-side mirror).
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_tokens: int,
        hidden_size: int,
        vocab_size: int,
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
    ):
        """Allocate every buffer at its maximum size.

        Args:
            max_num_reqs: Maximum number of requests in one batch.
            max_num_tokens: Maximum number of tokens in one batch.
            hidden_size: Not used by the buffers allocated here.
            vocab_size: Vocabulary size; sets the width of the grammar
                bitmask buffer.
            dtype: Not used by the buffers allocated here (each buffer has
                an explicit dtype).
            device: Device on which the device-side tensors are allocated.
            pin_memory: Forwarded to every ``CpuGpuBuffer``.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        # Must be set before any _make_buffer call, which reads them.
        self.device = device
        self.pin_memory = pin_memory

        # batch_idx -> req_state_idx.
        self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
        self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
        # Device-only tensor; filled on the device (e.g. prepare_pos_seq_lens).
        self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
        # One extra slot: offsets are [0, cumsum(query_lens)].
        self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
        # Device-only tensor; filled on the device (e.g. prepare_pos_seq_lens).
        self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)

        # Structured outputs.
        self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
        # cdiv(vocab_size, 32) int32 words per request -- presumably one bit
        # per vocabulary token.
        self.grammar_bitmask = self._make_buffer(
            max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
        )

    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
        """Create a ``CpuGpuBuffer`` on this object's device.

        ``args`` are forwarded unchanged to ``CpuGpuBuffer`` (the callers
        above pass the buffer dimensions). ``pin_memory`` and ``device``
        come from the values stored in ``__init__``.
        """
        return CpuGpuBuffer(
            *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
        )


@dataclass
class InputBatch:
    """Model inputs for one step over a batch of scheduled requests.

    Fields ending in ``_np`` are NumPy arrays holding the same values as the
    like-named tensors.
    """

    # batch_idx -> req_id
    req_ids: list[str]
    num_reqs: int

    # batch_idx -> req_state_idx
    idx_mapping: torch.Tensor
    idx_mapping_np: np.ndarray

    # [num_reqs]
    # batch_idx -> num_scheduled_tokens
    num_scheduled_tokens: np.ndarray
    # sum(num_scheduled_tokens)
    num_tokens: int
    num_tokens_after_padding: int

    # [num_reqs + 1]
    query_start_loc: torch.Tensor
    query_start_loc_np: np.ndarray
    # [num_reqs]
    seq_lens: torch.Tensor
    seq_lens_np: np.ndarray

    # [num_tokens_after_padding]
    input_ids: torch.Tensor
    # [num_tokens_after_padding]
    positions: torch.Tensor

    # layer_name -> Metadata (make_dummy leaves this as None)
    attn_metadata: dict[str, Any]

    # [num_reqs]
    logits_indices: torch.Tensor

    @classmethod
    def make_dummy(
        cls,
        num_reqs: int,
        num_tokens: int,
        input_buffers: InputBuffers,
        device: torch.device,
    ) -> "InputBatch":
        """Build a placeholder batch of ``num_reqs`` requests over ``num_tokens``.

        Tokens are split evenly across requests and the last request takes
        the remainder. Every request is treated as a fresh prefill, so
        ``seq_len == query_len``. As a side effect this writes the padded
        ``query_start_loc`` and zero-padded ``seq_lens`` into
        ``input_buffers``. ``input_ids`` are whatever values the buffer
        currently holds, and ``attn_metadata`` is ``None``.

        Args:
            num_reqs: Number of dummy requests; must satisfy
                ``0 < num_reqs <= num_tokens``.
            num_tokens: Total number of tokens in the batch (no padding).
            input_buffers: Persistent buffers the batch tensors are views of.
            device: Device for the freshly created ``idx_mapping`` tensor.
        """
        assert 0 < num_reqs <= num_tokens
        req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)

        # Even split; the last request absorbs the remainder.
        num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
        num_scheduled_tokens[-1] += num_tokens % num_reqs
        assert int(num_scheduled_tokens.sum()) == num_tokens

        # query_start_loc = [0, cumsum(query_lens)], padded with num_tokens so
        # the unused tail stays non-decreasing.
        qsl_np = input_buffers.query_start_loc.np
        qsl_np[0] = 0
        qsl_np[1 : num_reqs + 1] = np.cumsum(num_scheduled_tokens)
        qsl_np[num_reqs + 1 :] = num_tokens
        query_start_loc_np = qsl_np[: num_reqs + 1]
        # Copy the whole padded buffer, then view the live prefix.
        query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]

        # Fresh prefill: seq_len equals query_len.
        seq_lens_np = num_scheduled_tokens.copy()
        seq_lens_buf = input_buffers.seq_lens
        seq_lens_buf[:num_reqs] = num_tokens // num_reqs
        seq_lens_buf[num_reqs - 1] += num_tokens % num_reqs
        # Zero the unused tail (padding for full CUDA graphs).
        seq_lens_buf[num_reqs:] = 0
        seq_lens = seq_lens_buf[:num_reqs]

        # The token ids are left as-is; presumably their values do not
        # matter for a dummy run.
        input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
        positions = input_buffers.positions[:num_tokens]

        # Index of the last token of each request.
        logits_indices = query_start_loc[1:] - 1
        return cls(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens,
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            seq_lens_np=seq_lens_np,
            input_ids=input_ids,
            positions=positions,
            attn_metadata=None,  # type: ignore
            logits_indices=logits_indices,
        )


# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:],  # num_scheduled_tokens
            types.int32[:, :],  # prefill_token_ids
            types.int32[:],  # num_computed_prefill_tokens
            types.int32[:],  # prefill_len
            types.int32[:],  # input_ids
            types.int32[:],  # query_start_loc
        )
    ],
    nopython=True,
    cache=True,
)
def _prepare_prefill_inputs(
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    num_scheduled_tokens: np.ndarray,  # [B]
    prefill_token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_prefill_tokens: np.ndarray,  # [N]
    prefill_len: np.ndarray,  # [N]
    input_ids: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [>= B + 1]; entries past B are padding
) -> None:
    """Gather pending prompt tokens into ``input_ids``; fill ``query_start_loc``.

    For batch entry ``i`` with ``req_idx = idx_mapping[i]``, up to
    ``num_scheduled_tokens[i]`` prompt tokens are copied, starting at
    ``num_computed_prefill_tokens[req_idx]`` and never past
    ``prefill_len[req_idx]``. They go to the start of request ``i``'s slot
    in ``input_ids``. Slot positions not covered by prompt tokens are left
    untouched for other code to fill.

    ``query_start_loc`` receives ``[0, cumsum(num_scheduled_tokens)]`` and
    the rest of the array is padded with the total token count.
    """
    num_reqs = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0

    cu_num_tokens = 0
    for i in range(num_reqs):
        req_idx = idx_mapping[i]
        query_len = num_scheduled_tokens[i]

        # Prompt tokens still to be computed, capped at the prompt length.
        start = num_computed_prefill_tokens[req_idx]
        end = min(start + query_len, prefill_len[req_idx])
        n = end - start

        start_idx = cu_num_tokens
        # Defensive: n <= 0 means nothing is left to copy. A negative n would
        # otherwise produce a mis-sized slice assignment.
        if n > 0:
            input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]

        cu_num_tokens = start_idx + query_len
        query_start_loc[i + 1] = cu_num_tokens

    # Pad the inputs for CUDA graphs.
    # Note: pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention require that.
    query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)


def prepare_prefill_inputs(
    idx_mapping: np.ndarray,
    num_scheduled_tokens: np.ndarray,
    total_num_tokens: int,
    prefill_token_ids: np.ndarray,
    num_computed_prefill_tokens: np.ndarray,
    prefill_len: np.ndarray,
    input_ids: CpuGpuBuffer,
    query_start_loc: CpuGpuBuffer,
) -> None:
    """Fill prompt ``input_ids`` and ``query_start_loc`` on host, then copy to device.

    Args:
        idx_mapping: batch_idx -> req_idx.
        num_scheduled_tokens: Tokens scheduled per batch entry.
        total_num_tokens: Number of leading ``input_ids`` entries to copy
            to the device.
        prefill_token_ids: Per-request prompt token ids.
        num_computed_prefill_tokens: Per-request count of prompt tokens
            already computed.
        prefill_len: Per-request prompt length.
        input_ids: Output token-id buffer (host side written, then copied).
        query_start_loc: Output offsets buffer (host side written, then
            copied in full).
    """
    _prepare_prefill_inputs(
        idx_mapping,
        num_scheduled_tokens,
        prefill_token_ids,
        num_computed_prefill_tokens,
        prefill_len,
        input_ids.np,
        query_start_loc.np,
    )
    input_ids.copy_to_gpu(total_num_tokens)
    # NOTE(woosuk): We should copy the whole query_start_loc tensor from CPU
    # to GPU, because it may include paddings needed for full CUDA graph
    # mode.
    query_start_loc.copy_to_gpu()


# Launched with num_reqs + 1 programs (see prepare_pos_seq_lens). Programs
# 0..num_reqs-1 each compute one request's seq_len and token positions; the
# final program zero-fills seq_lens[num_reqs:max_num_reqs].
@triton.jit
def _prepare_pos_seq_lens_kernel(
    pos_ptr,
    seq_lens_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    num_computed_tokens_ptr,
    max_num_reqs,
    BLOCK_SIZE: tl.constexpr,
):
    req_id = tl.program_id(0)
    # The last program is the padding program, so the request count is
    # one less than the grid size.
    num_reqs = tl.num_programs(0) - 1
    if req_id == num_reqs:
        # Pad unused seq_lens as 0 for full CUDA graphs.
        for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
            block = i + tl.arange(0, BLOCK_SIZE)
            mask = block < max_num_reqs
            tl.store(seq_lens_ptr + block, 0, mask=mask)
        return

    req_state_idx = tl.load(idx_mapping_ptr + req_id)
    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)

    start = tl.load(query_start_loc_ptr + req_id)
    end = tl.load(query_start_loc_ptr + req_id + 1)
    query_len = end - start

    # Already-computed tokens plus the tokens scheduled in this step.
    seq_len = num_computed_tokens + query_len
    tl.store(seq_lens_ptr + req_id, seq_len)

    # Positions of the scheduled tokens continue from num_computed_tokens.
    for i in tl.range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        pos = num_computed_tokens + block
        tl.store(pos_ptr + start + block, pos, mask=mask)


def prepare_pos_seq_lens(
    idx_mapping: torch.Tensor,
    query_start_loc: torch.Tensor,
    num_computed_tokens: torch.Tensor,
    pos: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    """Compute token positions and per-request sequence lengths on device.

    Args:
        idx_mapping: batch_idx -> req_state_idx; its length is the number
            of requests in the batch.
        query_start_loc: Token offsets per request (``num_reqs + 1``
            entries are read).
        num_computed_tokens: Per-request-state count of computed tokens.
        pos: Output positions, written for every scheduled token.
        seq_lens: Output sequence lengths; its full length is treated as
            the maximum request count, and entries past ``num_reqs`` are
            zeroed.
    """
    seq_lens_capacity = seq_lens.shape[0]
    # One program per request, plus a trailing program that pads the unused
    # seq_lens entries with 0 for full CUDA graphs.
    launch_grid = (idx_mapping.shape[0] + 1,)
    _prepare_pos_seq_lens_kernel[launch_grid](
        pos,
        seq_lens,
        idx_mapping,
        query_start_loc,
        num_computed_tokens,
        seq_lens_capacity,
        BLOCK_SIZE=1024,
    )


# One program per batch entry. For requests whose sequence extends past the
# prompt, the last sampled token is written into the final slot of the
# request's query range. Only that single token is written; no draft tokens
# are handled here despite the name.
@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
    input_ids_ptr,
    idx_mapping_ptr,
    last_sampled_tokens_ptr,
    query_start_loc_ptr,
    seq_lens_ptr,
    prefill_len_ptr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    seq_len = tl.load(seq_lens_ptr + batch_idx)
    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
    if seq_len <= prefill_len:
        # Handling prefill tokens: input_ids already hold prompt tokens.
        return

    # Past the prompt: the final input token of this request is the last
    # sampled token (presumably from the previous step).
    last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
    end = tl.load(query_start_loc_ptr + batch_idx + 1)
    tl.store(input_ids_ptr + end - 1, last_token_id)


def combine_sampled_and_draft_tokens(
    input_ids: torch.Tensor,
    idx_mapping: torch.Tensor,
    last_sampled_tokens: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    prefill_len: torch.Tensor,
) -> torch.Tensor:
    """Write each non-prefill request's last sampled token into ``input_ids``.

    Requests with ``seq_len <= prefill_len`` are left untouched. For the
    others, the token at the last slot of their query range is overwritten.

    Args:
        input_ids: Token-id tensor, modified in place.
        idx_mapping: batch_idx -> req_state_idx.
        last_sampled_tokens: Per-request-state last sampled token.
        query_start_loc: Token offsets per request.
        seq_lens: Per-request sequence lengths; its length sets the number
            of programs launched.
        prefill_len: Per-request-state prompt length.

    Returns:
        ``input_ids`` (the same tensor that was passed in).
    """
    num_reqs = seq_lens.shape[0]
    _combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
        input_ids,
        idx_mapping,
        last_sampled_tokens,
        query_start_loc,
        seq_lens,
        prefill_len,
    )
    return input_ids


@triton.jit
def _update_num_computed_tokens_kernel(
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    query_start_loc_ptr,
):
    # One program per batch entry: add the entry's query length to the
    # num_computed_tokens counter of the request state it maps to. The update
    # is a plain load/store (not atomic), so idx_mapping entries are assumed
    # to be distinct -- TODO confirm with callers.
    batch_idx = tl.program_id(0)
    q_begin = tl.load(query_start_loc_ptr + batch_idx)
    q_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    state_idx = tl.load(idx_mapping_ptr + batch_idx)
    counter_ptr = num_computed_tokens_ptr + state_idx
    tl.store(counter_ptr, tl.load(counter_ptr) + (q_end - q_begin))


def update_num_computed_tokens(
    idx_mapping: torch.Tensor,
    num_computed_tokens: torch.Tensor,
    query_start_loc: torch.Tensor,
) -> None:
    """Advance each request's ``num_computed_tokens`` by its query length.

    Args:
        idx_mapping: batch_idx -> req_state_idx; its length is the number
            of requests processed.
        num_computed_tokens: Per-request-state counters, updated in place.
        query_start_loc: Token offsets; entries ``0..num_reqs`` are read.
    """
    grid = (idx_mapping.shape[0],)
    _update_num_computed_tokens_kernel[grid](
        idx_mapping, num_computed_tokens, query_start_loc
    )