# SPDX-License-Identifier: Apache-2.0
# Datastructures defining an input batch

from dataclasses import dataclass
5
from typing import Optional, cast
6
7
8
9

import numpy as np
import torch

10
from vllm.lora.request import LoRARequest
11
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
12
from vllm.sampling_params import SamplingParams, SamplingType
13
from vllm.utils import swap_dict_values
14
from vllm.v1.outputs import LogprobsTensors
15
from vllm.v1.sample.metadata import SamplingMetadata
16
from vllm.v1.utils import copy_slice
17
from vllm.v1.worker.block_table import BlockTable
18

19
20
_SAMPLING_EPS = 1e-5

21
22
23
24
25

@dataclass
class CachedRequestState:
    """State kept for a single request: its tokens and sampling settings.

    ``num_prompt_tokens`` is not a dataclass field; it is derived from
    ``prompt_token_ids`` once, right after construction.
    """

    req_id: str
    prompt_token_ids: list[int]
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list[PlaceholderRange]
    sampling_params: SamplingParams
    generator: Optional[torch.Generator]

    block_ids: list[int]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    def __post_init__(self):
        # Cache the prompt length so that it is not recomputed per lookup.
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        """Number of prompt tokens plus number of output tokens."""
        generated = len(self.output_token_ids)
        return self.num_prompt_tokens + generated

    def get_token_id(self, idx: int) -> int:
        """Return the token at ``idx`` in the prompt-then-output sequence."""
        offset = idx - self.num_prompt_tokens
        if offset >= 0:
            # Positions at or past the prompt length index the output tokens.
            return self.output_token_ids[offset]
        return self.prompt_token_ids[idx]
53
54
55
56
57
58
59
60
61
62
63


class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
    ):
        """Allocate the fixed-size buffers and empty bookkeeping containers.

        Args:
            max_num_reqs: Number of request slots; the leading dimension of
                every per-request buffer.
            max_model_len: Width of the per-request token-id buffer.
            max_num_blocks_per_req: Forwarded to ``BlockTable``.
            device: Device on which the sampling-parameter tensors live.
            pin_memory: Whether the CPU staging tensors are allocated pinned.
            vocab_size: Vocabulary size; used as the "no top-k" value, as the
                prompt padding id, and as the width of the allowed-token mask.
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        per_req_shape = (max_num_reqs, )

        def _staged_buffers(dtype: torch.dtype):
            # One device tensor, its CPU staging tensor, and a numpy view of
            # the CPU tensor (the view shares memory with the tensor).
            device_buf = torch.empty(per_req_shape,
                                     dtype=dtype,
                                     device=device)
            cpu_buf = torch.empty(per_req_shape,
                                  dtype=dtype,
                                  device="cpu",
                                  pin_memory=pin_memory)
            return device_buf, cpu_buf, cpu_buf.numpy()

        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            per_req_shape,
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = BlockTable(
            max_num_reqs=max_num_reqs,
            max_num_blocks_per_req=max_num_blocks_per_req,
            pin_memory=pin_memory,
            device=device,
        )

        # Sampling-related.
        (self.temperature, self.temperature_cpu_tensor,
         self.temperature_cpu) = _staged_buffers(torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        (self.top_p, self.top_p_cpu_tensor,
         self.top_p_cpu) = _staged_buffers(torch.float32)
        self.top_p_reqs: set[str] = set()

        (self.top_k, self.top_k_cpu_tensor,
         self.top_k_cpu) = _staged_buffers(torch.int32)
        self.top_k_reqs: set[str] = set()

        (self.min_p, self.min_p_cpu_tensor,
         self.min_p_cpu) = _staged_buffers(torch.float32)
        self.min_p_reqs: set[str] = set()

        # Frequency penalty related data structures
        (self.frequency_penalties, self.frequency_penalties_cpu_tensor,
         self.frequency_penalties_cpu) = _staged_buffers(torch.float)
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        (self.presence_penalties, self.presence_penalties_cpu_tensor,
         self.presence_penalties_cpu) = _staged_buffers(torch.float)
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        (self.repetition_penalties, self.repetition_penalties_cpu_tensor,
         self.repetition_penalties_cpu) = _staged_buffers(torch.float)
        self.repetition_penalties_reqs: set[str] = set()

        # req_index -> (min_tokens, stop_token_ids)
        self.min_tokens: dict[int, tuple[int, set[int]]] = {}

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        self.logit_bias: list[Optional[dict[int,
                                            float]]] = [None] * max_num_reqs
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        # Built last because it reads the attributes initialized above.
        self.sampling_metadata = self._make_sampling_metadata()

    @property
226
    def req_ids(self) -> list[str]:
227
228
        # None elements should only be present transiently
        # while performing state updates to the batch.
229
        return cast(list[str], self._req_ids)
230

231
232
233
234
235
236
237
238
239
240
    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Place ``request`` into the batch at slot ``req_index``.

        Every per-slot entry that a previous occupant may have left behind
        (including the logit bias and the allowed-token mask row) is
        overwritten, so the slot never inherits stale sampling state.

        Args:
            request: State of the request to add.
            req_index: Slot to fill. Defaults to ``self.num_reqs``, i.e. the
                slot right after the currently tracked requests.
        """
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)

        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            # Out-of-range values mean "no top-k"; store the full vocab size.
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k

        self.min_p_cpu[req_index] = sampling_params.min_p
        if sampling_params.min_p > _SAMPLING_EPS:
            self.min_p_reqs.add(req_id)

        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)

        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)

        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens,
                                          sampling_params.all_stop_token_ids)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs

        # Assign unconditionally (possibly None): condense() copies a moved
        # request's bias without clearing the vacated slot, so a conditional
        # write would let this request inherit the previous occupant's bias.
        self.logit_bias[req_index] = sampling_params.logit_bias

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
                                                          self.vocab_size,
                                                          dtype=torch.bool,
                                                          device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            # Mask everything, then unmask the allowed ids.
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False
        elif self.allowed_token_ids_mask_cpu_tensor is not None:
            # This request has no restriction: wipe whatever mask row an
            # earlier occupant of the slot may have left behind.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids.setdefault(lora_id,
                                                   set()).add(req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

353
    def remove_request(self, req_id: str) -> Optional[int]:
354
355
        """This method must always be followed by a call to condense()."""

356
357
358
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
359
360
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
361
362
363
364
365

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
366
        self.min_p_reqs.discard(req_id)
367
        self.min_tokens.pop(req_index, None)
368
369
370
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
371
372
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
373
        self.num_prompt_logprobs.pop(req_id, None)
374
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
375
376
377
378
379
380
381
382
383
384

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

385
        self.logit_bias[req_index] = None
386
387
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
388
            # False means we don't fill with -inf.
389
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
390
        self.bad_words_token_ids.pop(req_index, None)
391
392
        return req_index

393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-slot state between batch slots ``i1`` and ``i2``.

        Both slots must currently hold a request (asserted before any state
        is modified).
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        assert old_id_i1 is not None and old_id_i2 is not None

        self._req_ids[i1], self._req_ids[i2] = old_id_i2, old_id_i1
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]

        # Indexing a 1-D numpy array with an integer yields a scalar copy, so
        # a plain tuple swap is safe for these per-request arrays.
        for per_req in (
                self.num_tokens,
                self.num_tokens_no_spec,
                self.num_prompt_tokens,
                self.num_computed_tokens_cpu,
                self.temperature_cpu,
                self.top_p_cpu,
                self.top_k_cpu,
                self.frequency_penalties_cpu,
                self.presence_penalties_cpu,
                self.repetition_penalties_cpu,
                self.min_p_cpu,
                self.request_lora_mapping,
        ):
            per_req[i1], per_req[i2] = per_req[i2], per_req[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # because row indexing returns views: the first assignment overwrites
        # the data the second one reads. Instead, we need to temporarily copy
        # the data for one of the indices.
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.min_tokens, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        # List elements are object references, so a tuple swap is safe here.
        self.logit_bias[i1], self.logit_bias[i2] =\
            self.logit_bias[i2], self.logit_bias[i1]

        mask = self.allowed_token_ids_mask_cpu_tensor
        if mask is not None:
            # Rows of a 2-D tensor are views as well; a tuple swap would leave
            # both rows equal to row i2. Go through an explicit copy.
            saved_row = mask[i1].clone()
            mask[i1] = mask[i2]
            mask[i2] = saved_row

        self.block_table.swap_row(i1, i2)

451
    def condense(self, empty_req_indices: list[int]) -> None:
452
453
        num_reqs = self.num_reqs
        if num_reqs == 0:
454
            # The batched states are empty.
455
456
            self._req_ids.clear()
            self.req_output_token_ids.clear()
457
458
459
460
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
461
        last_req_index = num_reqs + len(empty_req_indices) - 1
462
463
464
465
466
467
468
469
470
471
472
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
473
474
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
475
            assert req_id is not None
476
477
478
479
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
480
481
            self.req_id_to_index[req_id] = empty_index

482
483
484
485
            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
486
487
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
488
489
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
490
491
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
492
            self.block_table.move_row(last_req_index, empty_index)
493
494
495
496
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
497
498
499
500
501
502
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
503
            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
504
505
506
507
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

508
509
510
511
            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token

512
513
514
            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

515
516
            self.logit_bias[empty_index] = self.logit_bias[last_req_index]

517
518
519
520
521
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

522
523
524
525
            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
526
527
528
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

529
530
531
532
533
534
535
536
537
        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]

    def refresh_sampling_metadata(self):
        """Rebuild ``self.sampling_metadata`` from the current batch state."""
        rebuilt = self._make_sampling_metadata()
        self.sampling_metadata = rebuilt

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` for the requests currently batched.

        CPU-side sampling parameters are copied to their device tensors with
        ``copy_slice`` only for the features that at least one request uses;
        unused features are passed as None where the original contract does.
        """
        batch_size = self.num_reqs
        greedy_only = self.all_greedy
        skip_top_p = self.no_top_p
        skip_top_k = self.no_top_k
        skip_min_p = self.no_min_p
        skip_penalties = self.no_penalties

        temperature = None
        if not greedy_only:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, batch_size)

        if not skip_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, batch_size)
        if not skip_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, batch_size)
        if not skip_min_p:
            copy_slice(self.min_p_cpu_tensor, self.min_p, batch_size)

        prompt_token_ids = None
        if not skip_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            for cpu_side, device_side in (
                (self.frequency_penalties_cpu_tensor,
                 self.frequency_penalties),
                (self.presence_penalties_cpu_tensor, self.presence_penalties),
                (self.repetition_penalties_cpu_tensor,
                 self.repetition_penalties),
            ):
                copy_slice(cpu_side, device_side, batch_size)

            # The prompt tokens are used only for applying penalties during
            # the sampling process. Hence copy these tensors only when
            # there are requests which need penalties to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, batch_size)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:batch_size]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=greedy_only,
            all_random=self.all_random,
            top_p=self.top_p[:batch_size] if not skip_top_p else None,
            top_k=self.top_k[:batch_size] if not skip_top_k else None,
            min_p=self.min_p[:batch_size] if not skip_min_p else None,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:batch_size],
            presence_penalties=self.presence_penalties[:batch_size],
            repetition_penalties=self.repetition_penalties[:batch_size],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            min_tokens=self.min_tokens,
            no_penalties=skip_penalties,
            logit_bias=self.logit_bias[:batch_size],
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
        )

596
597
598
599
600
601
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
602
603
            pin_memory=self.pin_memory,
        )
604
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
605
606
        prompt_token_ids[:] = self.token_ids_cpu[:self.
                                                 num_reqs, :max_prompt_len]
607
608
609
610
611
612
613
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

614
615
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
616
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """

        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
632
        active_lora_requests: set[LoRARequest] = set(
633
634
635
636
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

657
658
659
660
    @property
    def no_min_p(self) -> bool:
        return len(self.min_p_reqs) == 0

661
662
663
664
665
666
    @property
    def no_penalties(self) -> bool:
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

667
    @property
668
669
    def max_num_logprobs(self) -> Optional[int]:
        return max(self.num_logprobs.values()) if self.num_logprobs else None
670
671
672

    @property
    def no_prompt_logprob(self) -> bool:
673
        return not self.num_prompt_logprobs
674
675
676
677

    @property
    def no_allowed_token_ids(self) -> bool:
        return len(self.has_allowed_token_ids) == 0