# SPDX-License-Identifier: Apache-2.0

# Data structures defining an input batch.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

import numpy as np
import torch

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.multimodal.inputs import PlaceholderRange


@dataclass
class CachedRequestState:
    """State of a single request that is copied into an ``InputBatch`` slot.

    ``InputBatch.add_request`` reads the token ids, computed-token count,
    block ids, sampling params, optional generator and optional LoRA request
    from this object.
    """

    req_id: str
    prompt_token_ids: List[int]
    prompt: Optional[str]
    mm_inputs: List[MultiModalKwargs]
    mm_positions: List["PlaceholderRange"]
    sampling_params: SamplingParams
    # Per-request RNG; None when the request has no generator of its own.
    generator: Optional[torch.Generator]

    # Forwarded to the block table row for this request.
    block_ids: List[int]
    num_computed_tokens: int
    output_token_ids: List[int]

    # Not read within this file; presumably M-RoPE position data consumed
    # by the model runner for multimodal models -- confirm with callers.
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    # None (or falsy) means the request runs without a LoRA adapter.
    lora_request: Optional[LoRARequest] = None

    @property
    def num_tokens(self) -> int:
        """Total token count: prompt tokens plus generated output tokens."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)


class InputBatch:
    """Fixed-capacity, slot-indexed storage for the requests in a batch.

    Each active request occupies one row ("slot") of the preallocated
    per-request arrays. ``req_ids`` maps slot -> request id and
    ``req_id_to_index`` maps request id -> slot. Sampling parameters are
    written into CPU staging arrays and copied to the device tensors in
    :meth:`make_sampling_metadata`.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
    ):
        """Preallocate all per-request buffers for up to ``max_num_reqs``.

        Args:
            max_num_reqs: Number of request slots (rows) to allocate.
            max_model_len: Maximum number of tokens stored per request.
            max_num_blocks_per_req: Forwarded to the ``BlockTable``.
            device: Device on which the sampling tensors are allocated.
            pin_memory: Whether the CPU staging tensors that are copied to
                ``device`` are allocated in pinned memory.
            vocab_size: Vocabulary size; also used as the padding id for
                the padded prompt-token tensor.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Slot -> request id (None marks an empty slot), and its inverse.
        self.req_ids: List[Optional[str]] = [None] * max_num_reqs
        self.req_id_to_index: Dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        # NumPy view that shares memory with token_ids_cpu_tensor.
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

        # Block table.
        self.block_table = BlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_blocks_per_req=max_num_blocks_per_req,
            pin_memory=pin_memory,
            device=device,
        )

        # Sampling-related. Each parameter has a device tensor, a CPU staging
        # tensor (with a NumPy view for cheap per-slot writes), and a set of
        # request ids for which the parameter differs from its default.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: Set[str] = set()
        self.random_reqs: Set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: Set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: Set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: Set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = \
            self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: Set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: Set[str] = set()

        self.min_tokens: List[int] = [0] * max_num_reqs
        self.stop_token_ids: List[Set[int]] = [
            set() for _ in range(max_num_reqs)
        ]
        # Padded prompt ids on the device; rebuilt only when penalties apply.
        self.prompt_token_ids: Optional[torch.Tensor] = None

        # lora related
        # Slot -> LoRA id (0 means no LoRA).
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: Dict[int, Set[str]] = {}
        self.lora_id_to_lora_request: Dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: Dict[int, torch.Generator] = {}

        # req_id -> number of requested logprobs (only entries > 0).
        self.num_logprobs: Dict[str, int] = {}
        self.prompt_logprob_reqs: Set[str] = set()

    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Copy ``request``'s state into slot ``req_index``.

        Args:
            request: The cached per-request state to insert.
            req_index: Slot to write. Defaults to ``self.num_reqs``. That is
                the first free slot only when the batch has no holes (see
                :meth:`condense`).
        """
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        self.req_ids[req_index] = req_id
        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        self.num_tokens[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(req_index, request.block_ids)

        sampling_params = request.sampling_params
        self.temperature_cpu[req_index] = sampling_params.temperature
        if sampling_params.sampling_type == SamplingType.GREEDY:
            self.greedy_reqs.add(req_id)
        else:
            self.random_reqs.add(req_id)

        # Track only requests whose values differ from the no-op defaults,
        # so the batch-level no_top_p/no_top_k/no_penalties flags stay cheap.
        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        self.top_k_cpu[req_index] = sampling_params.top_k
        if sampling_params.top_k > 0:
            self.top_k_reqs.add(req_id)
        self.frequency_penalties_cpu[req_index] = \
            sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[req_index] = \
            sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[req_index] = \
            sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)
        self.min_tokens[req_index] = sampling_params.min_tokens
        self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        num_logprobs = sampling_params.logprobs
        if num_logprobs is not None and num_logprobs > 0:
            self.num_logprobs[req_id] = num_logprobs
        if sampling_params.prompt_logprobs:
            self.prompt_logprob_reqs.add(req_id)

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

    def remove_request(self, req_id: str) -> Optional[int]:
        """Remove ``req_id`` from the batch and leave its slot empty.

        The batch is not compacted here; call :meth:`condense` to fill the
        hole. Per-slot array contents are left in place and are overwritten
        by the next :meth:`add_request` into that slot.

        Returns:
            The slot the request occupied, or None if it was not present.
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self.req_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.prompt_logprob_reqs.discard(req_id)

        # LoRA: drop the adapter from the active set once its last request
        # is gone.
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

        return req_index

    def clear(self) -> None:
        """Drop every request and reset all bookkeeping sets and maps."""
        self.req_ids = [None] * self.max_num_reqs
        self.req_id_to_index.clear()
        self.greedy_reqs.clear()
        self.random_reqs.clear()
        self.top_p_reqs.clear()
        self.top_k_reqs.clear()
        self.frequency_penalties_reqs.clear()
        self.presence_penalties_reqs.clear()
        self.repetition_penalties_reqs.clear()
        self.generators.clear()
        self.num_logprobs.clear()
        self.prompt_logprob_reqs.clear()
        self.request_lora_mapping.fill(0)
        self.lora_id_to_lora_request.clear()
        self.lora_id_to_request_ids.clear()

    def condense(self, empty_req_indices: List[int]) -> None:
        """Move requests from the tail into empty slots.

        After this call the active requests occupy slots ``[0, num_reqs)``.

        Args:
            empty_req_indices: Indices of empty slots, sorted in descending
                order. The list is consumed in place (popped from the end).
        """
        if self.num_reqs == 0:
            # The batched states are empty.
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = self.num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
            req_id = self.req_ids[last_req_index]
            assert req_id is not None
            self.req_ids[empty_index] = req_id
            self.req_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_prompt_tokens[empty_index] = \
                self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = \
                self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[empty_index] = \
                self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[empty_index] = \
                self.repetition_penalties_cpu[last_req_index]
            self.min_tokens[empty_index] = self.min_tokens[last_req_index]
            self.stop_token_ids[empty_index] = \
                self.stop_token_ids[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

    def make_sampling_metadata(
        self,
        req_id_output_token_ids: Dict[str, List[int]],
        skip_copy: bool = False,
    ) -> SamplingMetadata:
        """Build the ``SamplingMetadata`` for the active requests.

        Args:
            req_id_output_token_ids: Maps request id -> that request's output
                token ids. It must contain every active request.
            skip_copy: If True, skip the host-to-device copies and reuse the
                device tensors (and ``prompt_token_ids``) from a prior call.
        """
        if not skip_copy:
            self.temperature[:self.num_reqs].copy_(
                self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_p[:self.num_reqs].copy_(
                self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
            self.top_k[:self.num_reqs].copy_(
                self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
            if not self.no_penalties:
                # Since syncing these tensors is expensive only copy them
                # if necessary i.e. if there are requests which require
                # penalties to be applied during sampling.
                self.frequency_penalties[:self.num_reqs].copy_(
                    self.frequency_penalties_cpu_tensor[:self.num_reqs],
                    non_blocking=True)
                self.presence_penalties[:self.num_reqs].copy_(
                    self.presence_penalties_cpu_tensor[:self.num_reqs],
                    non_blocking=True)
                self.repetition_penalties[:self.num_reqs].copy_(
                    self.repetition_penalties_cpu_tensor[:self.num_reqs],
                    non_blocking=True)
                # The prompt tokens are used only for applying penalties during
                # the sampling process. Hence copy these tensors only when
                # there are requests which need penalties to be applied.
                self.prompt_token_ids = self._make_prompt_token_ids_tensor()

        output_token_ids: List[List[int]] = []

        for req_id in self.req_ids[:self.num_reqs]:
            assert req_id is not None
            # Currently we create a tensor for output_token_ids from scratch
            # at each step. However, for the penalties computation what we
            # need is stats about the token ids present in the output. This
            # stats can be maintained incrementally instead of computing it
            # from scratch at each step.
            # TODO - Replace this with incremental update to output token
            # statistics.
            output_token_ids.append(req_id_output_token_ids[req_id])

        return SamplingMetadata(
            temperature=self.temperature[:self.num_reqs],
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=self.top_p[:self.num_reqs],
            top_k=self.top_k[:self.num_reqs],
            no_top_p=self.no_top_p,
            no_top_k=self.no_top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=self.prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:self.num_reqs],
            presence_penalties=self.presence_penalties[:self.num_reqs],
            repetition_penalties=self.repetition_penalties[:self.num_reqs],
            output_token_ids=output_token_ids,
            min_tokens=self.min_tokens[:self.num_reqs],
            stop_token_ids=self.stop_token_ids[:self.num_reqs],
            no_penalties=self.no_penalties,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Return a padded ``(num_reqs, max_prompt_len)`` int64 tensor of
        prompt token ids on ``self.device``.

        Positions past each request's prompt length are filled with
        ``self.vocab_size``, which is never a valid token id.
        """
        num_reqs = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:num_reqs]
        # ndarray.max() raises on an empty array, so guard the empty batch.
        max_prompt_len = int(prompt_lens.max()) if num_reqs > 0 else 0
        prompt_token_ids_cpu_tensor = torch.empty(
            (num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory)
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value. Column j of row i is padding iff
        # j >= prompt_lens[i]; build that mask by broadcasting instead of
        # looping over rows in Python.
        pad_mask = (np.arange(max_prompt_len)[None, :]
                    >= prompt_lens[:, None])
        prompt_token_ids[pad_mask] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """

        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
        active_lora_requests: Set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    @property
    def num_reqs(self) -> int:
        """Number of active requests in the batch."""
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        """True when no active request uses random sampling."""
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        """True when no active request uses greedy sampling."""
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        """True when no active request sets top_p < 1."""
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        """True when no active request sets top_k > 0."""
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when no active request uses any non-default penalty."""
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> int:
        """Largest logprobs count requested by any active request, else 0."""
        return max(self.num_logprobs.values()) if self.num_logprobs else 0

    @property
    def no_logprob(self) -> bool:
        """True when no active request asked for sample logprobs."""
        return len(self.num_logprobs) == 0

    @property
    def no_prompt_logprob(self) -> bool:
        """True when no active request asked for prompt logprobs."""
        return len(self.prompt_logprob_reqs) == 0