# SPDX-License-Identifier: Apache-2.0

# Data structures defining an input batch.

from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Optional, cast
6
7
8
9

import numpy as np
import torch

10
from vllm.lora.request import LoRARequest
11
12
13
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
14
from vllm.v1.utils import copy_slice
15
from vllm.v1.worker.block_table import BlockTable
16

17
18
# Threshold for treating a float sampling parameter as "enabled":
# add_request() only registers a request in min_p_reqs when its min_p
# exceeds this value.
_SAMPLING_EPS = 1e-5

if TYPE_CHECKING:
    from vllm.multimodal.inputs import PlaceholderRange


@dataclass
class CachedRequestState:
    """State of a single request, as consumed by ``InputBatch.add_request``.

    The field order defines the generated constructor signature, so it must
    not be changed.
    """

    req_id: str

    # The prompt as token ids, plus the original text when available.
    prompt_token_ids: list[int]
    prompt: Optional[str]

    # Multimodal inputs and their placeholder ranges (not read by
    # InputBatch itself).
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list["PlaceholderRange"]

    sampling_params: SamplingParams
    # Dedicated RNG for this request; None when the request has none.
    generator: Optional[torch.Generator]

    # Forwarded to BlockTable.add_row() by InputBatch.add_request().
    block_ids: list[int]
    num_computed_tokens: int
    output_token_ids: list[int]

    # M-RoPE positional state (not read by InputBatch itself); presumably
    # only populated for models that use M-RoPE -- confirm with callers.
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    @property
    def num_tokens(self) -> int:
        """Total token count: prompt tokens plus generated output tokens."""
        return sum(map(len, (self.prompt_token_ids, self.output_token_ids)))


class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
    ):
        """Pre-allocate every per-slot buffer for up to ``max_num_reqs``
        requests.

        Args:
            max_num_reqs: Number of batch slots.
            max_model_len: Width of the per-slot token id buffer.
            max_num_blocks_per_req: Forwarded to ``BlockTable``.
            device: Device holding the tensors handed to ``SamplingMetadata``.
            pin_memory: Whether the CPU staging tensors are allocated in
                pinned memory.
            vocab_size: Width of the (lazily allocated) allowed-token-ids
                mask and pad value for prompt token ids.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Slot -> request id (None marks a hole until condense()), and the
        # reverse lookup.
        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Per-slot counts: all tokens (may include spec decode tokens),
        # tokens without spec decode tokens, and prompt tokens.
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = BlockTable(
            max_num_reqs=max_num_reqs,
            max_num_blocks_per_req=max_num_blocks_per_req,
            pin_memory=pin_memory,
            device=device,
        )

        def _alloc_per_slot(
            dtype: torch.dtype
        ) -> tuple[torch.Tensor, torch.Tensor, np.ndarray]:
            # Uninitialized (device tensor, CPU staging tensor, NumPy view of
            # the CPU tensor), each of shape (max_num_reqs, ).
            on_device = torch.empty((max_num_reqs, ),
                                    dtype=dtype,
                                    device=device)
            on_cpu = torch.empty((max_num_reqs, ),
                                 dtype=dtype,
                                 device="cpu",
                                 pin_memory=pin_memory)
            return on_device, on_cpu, on_cpu.numpy()

        # Sampling-related state: for each parameter, the three buffers
        # above plus the set(s) of request ids that use it.
        (self.temperature, self.temperature_cpu_tensor,
         self.temperature_cpu) = _alloc_per_slot(torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        (self.top_p, self.top_p_cpu_tensor,
         self.top_p_cpu) = _alloc_per_slot(torch.float32)
        self.top_p_reqs: set[str] = set()

        (self.top_k, self.top_k_cpu_tensor,
         self.top_k_cpu) = _alloc_per_slot(torch.int32)
        self.top_k_reqs: set[str] = set()

        (self.min_p, self.min_p_cpu_tensor,
         self.min_p_cpu) = _alloc_per_slot(torch.float32)
        self.min_p_reqs: set[str] = set()

        (self.frequency_penalties, self.frequency_penalties_cpu_tensor,
         self.frequency_penalties_cpu) = _alloc_per_slot(torch.float)
        self.frequency_penalties_reqs: set[str] = set()

        (self.presence_penalties, self.presence_penalties_cpu_tensor,
         self.presence_penalties_cpu) = _alloc_per_slot(torch.float)
        self.presence_penalties_reqs: set[str] = set()

        (self.repetition_penalties, self.repetition_penalties_cpu_tensor,
         self.repetition_penalties_cpu) = _alloc_per_slot(torch.float)
        self.repetition_penalties_reqs: set[str] = set()

        # req_index -> (min_tokens, stop_token_ids)
        self.min_tokens: dict[int, tuple[int, set[int]]] = {}

        # LoRA: per-slot LoRA id (0 means no LoRA) plus lookups by LoRA id.
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        self.logit_bias: list[Optional[dict[int,
                                            float]]] = [None] * max_num_reqs
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        # Both mask tensors are allocated lazily by add_request().
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

    @property
    def req_ids(self) -> list[str]:
        """Request ids ordered by batch slot.

        ``None`` holes exist only transiently while the batch is being
        updated (between ``remove_request()`` and ``condense()``); the cast
        states the steady-state contract.
        """
        slot_ids = self._req_ids
        return cast(list[str], slot_ids)

    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Write ``request`` into batch slot ``req_index``.

        Args:
            request: Cached state of the request to add.
            req_index: Target slot. Defaults to ``self.num_reqs``, i.e. the
                slot right after the currently occupied ones.

        Every slot-local field is (re)written here, including the optional
        ones (logit bias, allowed-token-ids mask row, min_tokens and
        generator entries). The result therefore does not depend on whatever
        a previous occupant left in the slot.
        """
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids followed by the output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        self.top_k_cpu[req_index] = sampling_params.top_k
        if sampling_params.top_k > 0:
            self.top_k_reqs.add(req_id)
        self.min_p_cpu[req_index] = sampling_params.min_p
        if sampling_params.min_p > _SAMPLING_EPS:
            self.min_p_reqs.add(req_id)

        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens,
                                          sampling_params.all_stop_token_ids)
        else:
            # Drop any entry a previous occupant of this slot left behind.
            self.min_tokens.pop(req_index, None)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator
        else:
            self.generators.pop(req_index, None)

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs

        # Assign unconditionally: None clears a stale bias in this slot.
        self.logit_bias[req_index] = sampling_params.logit_bias

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
                                                          self.vocab_size,
                                                          dtype=torch.bool,
                                                          device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            # Disallow every token, then re-allow the requested ids.
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False
        elif self.allowed_token_ids_mask_cpu_tensor is not None:
            # This request is unrestricted: clear the row so a mask left by a
            # previous occupant of the slot is not applied to it.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

    def remove_request(self, req_id: str) -> Optional[int]:
        """Unregister ``req_id`` and return the slot it occupied.

        Returns None when ``req_id`` is not part of the batch. The slot is
        left as a hole (``None`` in ``_req_ids``), so this method must always
        be followed by a call to condense().
        """
        slot = self.req_id_to_index.pop(req_id, None)
        if slot is None:
            return None

        self._req_ids[slot] = None
        self.req_output_token_ids[slot] = None

        # Membership sets keyed by request id.
        for id_set in (self.greedy_reqs, self.random_reqs, self.top_p_reqs,
                       self.top_k_reqs, self.min_p_reqs,
                       self.frequency_penalties_reqs,
                       self.presence_penalties_reqs,
                       self.repetition_penalties_reqs):
            id_set.discard(req_id)

        # Optional entries keyed by slot index or by request id.
        self.min_tokens.pop(slot, None)
        self.generators.pop(slot, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)

        # LoRA bookkeeping: forget a LoRA once no request uses it anymore.
        lora_id = self.request_lora_mapping[slot]
        if lora_id != 0:
            users = self.lora_id_to_request_ids[lora_id]
            users.discard(req_id)
            if not users:
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[slot] = 0

        self.logit_bias[slot] = None
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[slot].fill_(False)
        return slot

    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-slot state between batch slots ``i1`` and ``i2``.

        Both slots must currently hold a request.

        Rows of NumPy arrays / torch tensors must NOT be swapped with
        ``a[i], a[j] = a[j], a[i]``: basic indexing yields views, so the
        first assignment overwrites the data the second one reads, and both
        rows end up holding the original row ``j``. Indexing with a list
        (``a[[j, i]]``) returns a copy, which makes the swap safe for rows and
        scalars alike.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        assert old_id_i1 is not None and old_id_i2 is not None

        # Python lists store references, so plain tuple swaps are safe here.
        self._req_ids[i1], self._req_ids[i2] = old_id_i2, old_id_i1
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = \
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        self.logit_bias[i1], self.logit_bias[i2] = \
            self.logit_bias[i2], self.logit_bias[i1]
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = \
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]

        # Copy-based swaps for every per-slot array (see docstring).
        src, dst = [i2, i1], [i1, i2]
        for per_slot in (self.token_ids_cpu, self.num_tokens,
                         self.num_tokens_no_spec, self.num_prompt_tokens,
                         self.num_computed_tokens_cpu, self.temperature_cpu,
                         self.top_p_cpu, self.top_k_cpu, self.min_p_cpu,
                         self.frequency_penalties_cpu,
                         self.presence_penalties_cpu,
                         self.repetition_penalties_cpu,
                         self.request_lora_mapping):
            per_slot[dst] = per_slot[src]

        # The allowed-token-ids mask rows belong to the slots as well.
        mask = self.allowed_token_ids_mask_cpu_tensor
        if mask is not None:
            mask[dst] = mask[src]

        def _swap_entries(mapping: dict, k1: int, k2: int) -> None:
            # A slot without a value has no key at all, so a missing entry
            # must stay missing after the swap (no stale copy left behind).
            v1 = mapping.pop(k1, None)
            v2 = mapping.pop(k2, None)
            if v1 is not None:
                mapping[k2] = v1
            if v2 is not None:
                mapping[k1] = v2

        _swap_entries(self.generators, i1, i2)
        _swap_entries(self.min_tokens, i1, i2)

        self.block_table.swap_row(i1, i2)

    def condense(self, empty_req_indices: list[int]) -> None:
        """Compact the batch so that live requests occupy slots
        ``[0, num_reqs)``.

        Args:
            empty_req_indices: Slots freed by ``remove_request()``, sorted in
                descending order. Entries are popped from this list.

        The request in the highest live slot is repeatedly moved into the
        lowest empty slot. Optional slot-local state (logit bias and the
        allowed-token-ids mask row) is cleared at the source slot after the
        move, so it cannot leak into whichever request is later placed there.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Move the request at last_req_index into empty_index.
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            for per_slot in (self.num_tokens_no_spec, self.num_prompt_tokens,
                             self.num_computed_tokens_cpu,
                             self.temperature_cpu, self.top_p_cpu,
                             self.top_k_cpu, self.min_p_cpu,
                             self.frequency_penalties_cpu,
                             self.presence_penalties_cpu,
                             self.repetition_penalties_cpu,
                             self.request_lora_mapping):
                per_slot[empty_index] = per_slot[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)

            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token

            # Move the optional per-slot state, then reset the source slot.
            self.logit_bias[empty_index] = self.logit_bias[last_req_index]
            self.logit_bias[last_req_index] = None
            mask = self.allowed_token_ids_mask_cpu_tensor
            if mask is not None:
                mask[empty_index] = mask[last_req_index]
                # False means we don't fill with -inf.
                mask[last_req_index].fill_(False)

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]

    def refresh_sampling_metadata(self):
        """Rebuild ``self.sampling_metadata`` from the current batch state."""
        metadata = self._make_sampling_metadata()
        self.sampling_metadata = metadata

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Copy the active sampling parameters of slots ``[0, num_reqs)``
        from the CPU staging tensors into the device tensors and wrap them
        in a ``SamplingMetadata``.

        A parameter is only copied when at least one request uses it;
        otherwise its metadata field is None (temperature is None when every
        request is greedy). The penalty tensors are always passed as slices,
        but they are only refreshed when ``no_penalties`` is False.
        """
        n = self.num_reqs

        temperature: Optional[torch.Tensor] = None
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, n)

        top_p: Optional[torch.Tensor] = None
        top_k: Optional[torch.Tensor] = None
        min_p: Optional[torch.Tensor] = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, n)
            top_p = self.top_p[:n]
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, n)
            top_k = self.top_k[:n]
        if not self.no_min_p:
            copy_slice(self.min_p_cpu_tensor, self.min_p, n)
            min_p = self.min_p[:n]

        prompt_token_ids: Optional[torch.Tensor] = None
        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            penalty_buffers = (
                (self.frequency_penalties_cpu_tensor,
                 self.frequency_penalties),
                (self.presence_penalties_cpu_tensor,
                 self.presence_penalties),
                (self.repetition_penalties_cpu_tensor,
                 self.repetition_penalties),
            )
            for cpu_src, device_dst in penalty_buffers:
                copy_slice(cpu_src, device_dst, n)
            # The prompt tokens are used only for applying penalties during
            # the sampling process, so they are built only in this case.
            prompt_token_ids = self._make_prompt_token_ids_tensor()

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, n)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:n]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:n],
            presence_penalties=self.presence_penalties[:n],
            repetition_penalties=self.repetition_penalties[:n],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            min_tokens=self.min_tokens,
            no_penalties=self.no_penalties,
            logit_bias=self.logit_bias[:n],
            allowed_token_ids_mask=allowed_token_ids_mask,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Build a ``(num_reqs, max_prompt_len)`` int64 tensor of prompt
        token ids on ``self.device``, right-padded with ``self.vocab_size``.

        ``vocab_size`` serves as the pad value because no real token id
        takes that value.
        """
        num_reqs = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:num_reqs]
        max_prompt_len = prompt_lens.max()
        staging = torch.empty(
            (num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        staged = staging.numpy()
        staged[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Every column at or beyond a row's prompt length is padding.
        is_pad = np.arange(max_prompt_len) >= prompt_lens[:, None]
        staged[is_pad] = self.vocab_size
        return staging.to(device=self.device, non_blocking=True)

    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of every LoRA request currently registered
               in the batch.
        """
        per_request = self.request_lora_mapping[:self.num_reqs]
        # Expand each request's LoRA id once per scheduled token.
        per_token = np.repeat(per_request, num_scheduled_tokens)
        return (
            tuple(per_request),
            tuple(per_token),
            set(self.lora_id_to_lora_request.values()),
        )

    @property
    def num_reqs(self) -> int:
        """Number of requests currently registered in the batch."""
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        """True when no request in the batch uses random sampling."""
        return not self.random_reqs

    @property
    def all_random(self) -> bool:
        """True when no request in the batch uses greedy sampling."""
        return not self.greedy_reqs

    @property
    def no_top_p(self) -> bool:
        """True when no request applies top-p filtering."""
        return not self.top_p_reqs

    @property
    def no_top_k(self) -> bool:
        """True when no request applies top-k filtering."""
        return not self.top_k_reqs

    @property
    def no_min_p(self) -> bool:
        """True when no request applies min-p filtering."""
        return not self.min_p_reqs

    @property
    def no_penalties(self) -> bool:
        """True when no request uses a presence, frequency or repetition
        penalty."""
        return not (self.presence_penalties_reqs
                    or self.frequency_penalties_reqs
                    or self.repetition_penalties_reqs)

    @property
    def max_num_logprobs(self) -> Optional[int]:
        """Largest ``logprobs`` value requested in the batch, or None when
        no request asked for logprobs."""
        if not self.num_logprobs:
            return None
        return max(self.num_logprobs.values())

    @property
    def no_prompt_logprob(self) -> bool:
        """True when no request currently wants prompt logprobs."""
        return len(self.num_prompt_logprobs) == 0

    @property
    def no_allowed_token_ids(self) -> bool:
        """True when no request restricts its allowed token ids."""
        return not self.has_allowed_token_ids