gpu_input_batch.py 33.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import Optional, cast
7
8
9

import numpy as np
import torch
10
from typing_extensions import deprecated
11

12
from vllm.lora.request import LoRARequest
13
14
from vllm.multimodal.inputs import (MultiModalKwargsItem,
                                    MultiModalKwargsItems, PlaceholderRange)
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams, SamplingType
17
from vllm.utils import swap_dict_values
18
from vllm.v1.outputs import LogprobsTensors
19
from vllm.v1.pool.metadata import PoolingMetadata
20
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
21
22
                                             LogitsProcessors,
                                             MoveDirectionality)
23
from vllm.v1.sample.metadata import SamplingMetadata
24
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
25
from vllm.v1.utils import copy_slice
26
from vllm.v1.worker.block_table import MultiGroupBlockTable
27
28
29
30
31
32


@dataclass
class CachedRequestState:
    """State kept for a single request.

    ``num_prompt_tokens`` is not a constructor argument; it is derived from
    ``prompt_token_ids`` in ``__post_init__``.
    """

    req_id: str
    prompt_token_ids: list[int]
    mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
    mm_hashes: list[str]
    # A request carries sampling params or pooling params; both fields are
    # optional.
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    def __post_init__(self):
        # Computed once at construction time from the prompt length.
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        """Prompt length plus the number of output tokens so far."""
        num_output_tokens = len(self.output_token_ids)
        return self.num_prompt_tokens + num_output_tokens

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        """Deprecated view of ``mm_kwargs``: one single-item
        ``MultiModalKwargsItems`` per entry."""
        wrapped: list[MultiModalKwargsItems] = []
        for item in self.mm_kwargs:
            wrapped.append(MultiModalKwargsItems.from_seq([item]))
        return wrapped

    def get_token_id(self, idx: int) -> int:
        """Return the token at position ``idx`` of prompt + output tokens.

        Positions below ``num_prompt_tokens`` index ``prompt_token_ids``;
        all others index ``output_token_ids`` after subtracting the prompt
        length.
        """
        output_idx = idx - self.num_prompt_tokens
        if output_idx < 0:
            return self.prompt_token_ids[idx]
        return self.output_token_ids[output_idx]
70
71
72
73
74


class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
    ):
        """Allocate the per-slot buffers and bookkeeping containers.

        Args:
            max_num_reqs: Number of request slots; every per-request buffer
                created here has this many entries (rows).
            max_model_len: Row width of the CPU token-id buffer; also
                forwarded to the block table.
            max_num_batched_tokens: Stored and forwarded to the block table.
            device: Device on which the device-side sampling tensors are
                allocated.
            pin_memory: Whether the CPU staging tensors are pinned.
            vocab_size: Vocabulary size, stored for use by other methods.
            block_sizes: The block_size of each kv cache group.
            logitsprocs: Logits processors to use; an empty
                ``LogitsProcessors()`` is created when falsy.
            is_spec_decode: Stored as ``self.is_spec_decode``.
            is_pooling_model: Stored as ``self.is_pooling_model``.
        """
        per_req_shape = (max_num_reqs, )

        def device_buffer(dtype: torch.dtype) -> torch.Tensor:
            # Uninitialized per-request tensor on the target device.
            return torch.empty(per_req_shape, dtype=dtype, device=device)

        def cpu_buffer(dtype: torch.dtype) -> torch.Tensor:
            # Uninitialized per-request CPU staging tensor.
            return torch.empty(per_req_shape,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        # NumPy view of the tensor above.
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            per_req_shape,
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
        )

        # Sampling-related. Each parameter has a device tensor, a CPU staging
        # tensor, and a NumPy view of the CPU tensor.
        self.temperature = device_buffer(torch.float32)
        self.temperature_cpu_tensor = cpu_buffer(torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = device_buffer(torch.float32)
        self.top_p_cpu_tensor = cpu_buffer(torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = device_buffer(torch.int32)
        self.top_k_cpu_tensor = cpu_buffer(torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = device_buffer(torch.float)
        self.frequency_penalties_cpu_tensor = cpu_buffer(torch.float)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = device_buffer(torch.float)
        self.presence_penalties_cpu_tensor = cpu_buffer(torch.float)
        self.presence_penalties_cpu = \
            self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = device_buffer(torch.float)
        self.repetition_penalties_cpu_tensor = cpu_buffer(torch.float)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        # Both mask tensors stay None until add_request first needs them.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
                                                          dtype=bool)

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

253
    @property
    def req_ids(self) -> list[str]:
        """Return the request-id list, typed as containing no ``None``.

        The underlying list is returned as-is (no copy); ``cast`` only
        affects static typing.
        """
        # None elements should only be present transiently
        # while performing state updates to the batch.
        ids = self._req_ids
        return cast(list[str], ids)
258

259
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Track add-request operations for logits processors.
        Not applicable to pooling models.

        Args:
          request: request about to be added to the batch

        Returns:
          Index at which the request should be stored
        """
        builder = self.batch_update_builder

        # Prefer an index handed back by the builder's removed list; when
        # there is none, the request goes right after the current requests.
        new_req_index = builder.pop_removed()
        if new_req_index is None:
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
        builder.batch_changed = True

        params = request.sampling_params
        if params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            added_entry = (new_req_index, params, request.prompt_token_ids,
                           request.output_token_ids)
            builder.added.append(added_entry)

        return new_req_index
279

280
281
282
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Store a request in the batch and record its per-request state.

        Args:
          request: request to add

        Returns:
          Index assigned to the request

        Raises:
          NotImplementedError: if the request has neither sampling params
            nor pooling params
        """
        req_index = self._register_add_request(request)
        req_id = request.req_id
        prompt_ids = request.prompt_token_ids
        output_ids = request.output_token_ids

        # Either extend the id/output lists by one or overwrite a reused slot.
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(output_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = output_ids
        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        prompt_len = len(prompt_ids)
        output_end = prompt_len + len(output_ids)
        self.num_prompt_tokens[req_index] = prompt_len
        token_row = self.token_ids_cpu[req_index]
        token_row[:prompt_len] = prompt_ids
        token_row[prompt_len:output_end] = output_ids

        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        pooling_params = request.pooling_params
        if sampling_params:
            if (self.is_spec_decode
                    and is_spec_decode_unsupported(sampling_params)):
                self.spec_decode_unsupported_reqs.add(req_id)

            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Avoid later division by zero.
                self.temperature_cpu[req_index] = -1.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)

            # top_k values outside (0, vocab_size) are stored as vocab_size
            # and the request is not tracked in top_k_reqs.
            requested_top_k = sampling_params.top_k
            if 0 < requested_top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
                self.top_k_cpu[req_index] = requested_top_k
            else:
                self.top_k_cpu[req_index] = self.vocab_size

            frequency_penalty = sampling_params.frequency_penalty
            self.frequency_penalties_cpu[req_index] = frequency_penalty
            if frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)

            presence_penalty = sampling_params.presence_penalty
            self.presence_penalties_cpu[req_index] = presence_penalty
            if presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)

            repetition_penalty = sampling_params.repetition_penalty
            self.repetition_penalties_cpu[req_index] = repetition_penalty
            if repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # A value of -1 is stored as vocab_size.
                if sampling_params.logprobs == -1:
                    self.num_logprobs[req_id] = self.vocab_size
                else:
                    self.num_logprobs[req_id] = sampling_params.logprobs
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[
                    req_id] = sampling_params.prompt_logprobs

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    mask_shape = (self.max_num_reqs, self.vocab_size)
                    self.allowed_token_ids_mask = torch.zeros(
                        *mask_shape, dtype=torch.bool, device=self.device)
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        *mask_shape, dtype=torch.bool, device="cpu")
                # Mask the whole row, then unmask the allowed tokens.
                mask_row = self.allowed_token_ids_mask_cpu_tensor[req_index]
                mask_row.fill_(True)
                # False means we don't fill with -inf.
                mask_row[sampling_params.allowed_token_ids] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[
                    req_index] = sampling_params.bad_words_token_ids
        elif pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids)
        else:
            raise NotImplementedError("Unrecognized request type")

        # Add request lora ID
        lora_request = request.lora_request
        if lora_request:
            lora_id = lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

406
    def remove_request(self, req_id: str) -> Optional[int]:
        """This method must always be followed by a call to condense().

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        # Record the freed index and blank the slot's list entries.
        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            remaining_req_ids = self.lora_id_to_request_ids[lora_id]
            remaining_req_ids.discard(req_id)
            if not remaining_req_ids:
                # Last request using this adapter: drop both lookup entries.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            return req_index

        # Sets keyed by request id.
        for tracked_reqs in (
                self.greedy_reqs,
                self.random_reqs,
                self.top_p_reqs,
                self.top_k_reqs,
                self.spec_decode_unsupported_reqs,
                self.frequency_penalties_reqs,
                self.presence_penalties_reqs,
                self.repetition_penalties_reqs,
                self.has_allowed_token_ids,
        ):
            tracked_reqs.discard(req_id)

        # Dicts keyed by request id.
        for per_req_state in (
                self.num_logprobs,
                self.num_prompt_logprobs,
                self.in_progress_prompt_logprobs_cpu,
        ):
            per_req_state.pop(req_id, None)

        # Dicts keyed by request index.
        self.generators.pop(req_index, None)
        self.bad_words_token_ids.pop(req_index, None)

        mask_cpu = self.allowed_token_ids_mask_cpu_tensor
        if mask_cpu is not None:
            # False means we don't fill with -inf.
            mask_cpu[req_index].fill_(False)

        return req_index

458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state stored at indices ``i1`` and ``i2``.

        For pooling models only the request ids, token buffers, block table
        rows and LoRA mapping are swapped. Otherwise the swap is also
        recorded in ``self.batch_update_builder.moved`` and the sampling
        parameters, generators, bad-words lists and allowed-token mask rows
        are swapped as well.

        Args:
          i1: index of the first request
          i2: index of the second request
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1] # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]

        # 1-D per-request arrays: indexing yields scalars, so tuple
        # assignment swaps them safely.
        for per_req in (
                self.num_tokens,
                self.num_tokens_no_spec,
                self.num_prompt_tokens,
                self.num_computed_tokens_cpu,
        ):
            per_req[i1], per_req[i2] = per_req[i2], per_req[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append(
            (i1, i2, MoveDirectionality.SWAP))

        for sampling_values in (
                self.temperature_cpu,
                self.top_p_cpu,
                self.top_k_cpu,
                self.frequency_penalties_cpu,
                self.presence_penalties_cpu,
                self.repetition_penalties_cpu,
        ):
            sampling_values[i1], sampling_values[i2] = \
                sampling_values[i2], sampling_values[i1]

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        mask_cpu = self.allowed_token_ids_mask_cpu_tensor
        if mask_cpu is not None:
            # Indexing a row of this 2-D tensor returns a view, so a plain
            # tuple swap would overwrite row i1 before it is read back and
            # leave both rows equal to row i2. Copy one row first, as is
            # done for token_ids_cpu above.
            saved_row = mask_cpu[i1].clone()
            mask_cpu[i1] = mask_cpu[i2]
            mask_cpu[i2] = saved_row
521

522
523
524
525
526
527
528
529
530
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        Nothing is returned. For non-pooling models every move is recorded
        in ``self.batch_update_builder.moved`` as a
        ``(from_index, to_index, MoveDirectionality.UNIDIRECTIONAL)`` tuple.
        Indices consumed by a move are popped from the builder's removed
        list; indices that were not filled are left in it.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                # Every remaining empty index lies above the last live
                # request, so there is nothing left to move.
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # Only the first num_tokens entries of the row are copied.
            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index,
                 MoveDirectionality.UNIDIRECTIONAL))

            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            mask_cpu = self.allowed_token_ids_mask_cpu_tensor
            if mask_cpu is not None:
                mask_cpu[empty_index] = mask_cpu[last_req_index]
                # Reset the vacated row. add_request only writes a mask row
                # when the request has allowed_token_ids, so a request later
                # stored at this index without them would otherwise inherit
                # the moved request's mask.
                # False means we don't fill with -inf.
                mask_cpu[last_req_index].fill_(False)

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
627

628
    def refresh_metadata(self):
        """Apply any batch updates to sampling metadata."""
        builder = self.batch_update_builder

        if self.is_pooling_model:
            # Pooling models only reset the change tracker; the metadata is
            # rebuilt when the reset reports a change.
            if builder.reset():
                self.sampling_metadata = self._make_sampling_metadata()
        else:
            # For non-pooling models - generate and apply logitsprocs update;
            # reset batch update tracking.
            # Update sampling metadata if batch state is changed.
            batch_update = builder.get_and_reset(self.num_reqs)
            for logit_proc in self.logitsprocs.all:
                logit_proc.update_state(batch_update)
            if batch_update:
                self.sampling_metadata = self._make_sampling_metadata()
645
646
647

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` for the first ``num_reqs`` slots.

        CPU staging tensors are copied to their device tensors with
        ``copy_slice`` only for the features that are in use.
        """
        num_reqs = self.num_reqs

        temperature = (None if self.all_greedy else copy_slice(
            self.temperature_cpu_tensor, self.temperature, num_reqs))

        # NOTE(review): the flags below are read once and reused; this
        # assumes the properties are side-effect free — confirm.
        no_top_p = self.no_top_p
        if not no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        no_top_k = self.no_top_k
        if not no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        no_penalties = self.no_penalties
        if not no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            penalty_buffers = (
                (self.frequency_penalties_cpu_tensor,
                 self.frequency_penalties),
                (self.presence_penalties_cpu_tensor, self.presence_penalties),
                (self.repetition_penalties_cpu_tensor,
                 self.repetition_penalties),
            )
            for cpu_tensor, device_tensor in penalty_buffers:
                copy_slice(cpu_tensor, device_tensor, num_reqs)

        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        needs_prompt_token_ids = (
            not no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any())
        prompt_token_ids = (self._make_prompt_token_ids_tensor()
                            if needs_prompt_token_ids else None)

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        top_p = None if no_top_p else self.top_p[:num_reqs]
        top_k = None if no_top_k else self.top_k[:num_reqs]
        output_token_ids = cast(list[list[int]], self.req_output_token_ids)

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=top_p,
            top_k=top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=output_token_ids,
            no_penalties=no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

707
708
709
710
711
712
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of each request, ordered like req_ids.

        Asserts that ``req_ids`` and ``pooling_params`` hold the same number
        of entries before doing the lookups.
        """
        assert len(self.req_ids) == len(self.pooling_params)
        params_by_req_id = self.pooling_params
        ordered_params = []
        for request_id in self.req_ids:
            ordered_params.append(params_by_req_id[request_id])
        return ordered_params

    def get_pooling_metadata(self) -> PoolingMetadata:
        """Build the PoolingMetadata for the requests currently in the batch.

        Returns:
            A PoolingMetadata holding the prompt lengths of the first
            ``num_reqs`` requests (as a tensor sharing memory with
            ``self.num_prompt_tokens``), the prompt token ids stored on the
            current ``self.sampling_metadata``, and the per-request pooling
            params in batch order.
        """
        pooling_params = self.get_pooling_params()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(
                self.num_prompt_tokens[:self.num_reqs]),
            # Reuses whatever the current sampling metadata holds; this is
            # None if it was built without prompt token ids.
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
        )

721
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Build a padded (num_reqs, max_prompt_len) tensor of prompt tokens.

        Row ``i`` holds the first ``num_prompt_tokens[i]`` entries of
        ``token_ids_cpu[i]``; the remaining positions are filled with
        ``vocab_size`` as a padding value.

        Returns:
            An int64 tensor moved to ``self.device`` with a non-blocking
            copy. For an empty batch the result has shape ``(0, 0)``.
        """
        num_reqs = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:num_reqs]
        # `initial=0` keeps the reduction defined for an empty batch.
        max_prompt_len = int(prompt_lens.max(initial=0))

        prompt_token_ids_cpu_tensor = torch.empty(
            (num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        # The numpy view shares memory with the tensor, so writes below
        # land directly in `prompt_token_ids_cpu_tensor`.
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]

        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad_mask = np.arange(max_prompt_len) >= prompt_lens[:, None]
        prompt_token_ids[pad_mask] = self.vocab_size

        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

739
740
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: Per-request repeat counts; entry i is the
                number of times request i's LoRA id is repeated in
                token_lora_mapping.

        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        # Expand the per-request ids to per-token ids.
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
    @property
    def num_reqs(self) -> int:
        """Number of entries in ``req_id_to_index``."""
        index_by_req_id = self.req_id_to_index
        return len(index_by_req_id)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty."""
        num_random = len(self.random_reqs)
        return not num_random

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty."""
        num_greedy = len(self.greedy_reqs)
        return not num_greedy

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` is empty."""
        num_top_p = len(self.top_p_reqs)
        return not num_top_p

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` is empty."""
        num_top_k = len(self.top_k_reqs)
        return not num_top_k

782
783
784
785
786
787
    @property
    def no_penalties(self) -> bool:
        """True when the presence, frequency and repetition sets are all empty.

        The sets are checked in that order and the check stops at the first
        non-empty one.
        """
        if len(self.presence_penalties_reqs) != 0:
            return False
        if len(self.frequency_penalties_reqs) != 0:
            return False
        return len(self.repetition_penalties_reqs) == 0

788
    @property
    def max_num_logprobs(self) -> Optional[int]:
        """Largest value in ``num_logprobs``, or None when it is empty."""
        return max(self.num_logprobs.values()) if self.num_logprobs else None
791
792
793

    @property
    def no_prompt_logprob(self) -> bool:
        """True when ``num_prompt_logprobs`` is empty (falsy)."""
        return not self.num_prompt_logprobs
795
796
797
798

    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` is empty."""
        num_restricted = len(self.has_allowed_token_ids)
        return not num_restricted