gpu_input_batch.py 39.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4

5
from dataclasses import dataclass, field
6
from typing import Optional, cast
7
8
9

import numpy as np
import torch
10
from typing_extensions import deprecated
11

12
from vllm import envs
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams, SamplingType
17
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
18
from vllm.v1.outputs import LogprobsTensors
19
from vllm.v1.pool.metadata import PoolingMetadata
20
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
21
22
                                             LogitsProcessors,
                                             MoveDirectionality)
23
from vllm.v1.sample.metadata import SamplingMetadata
24
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
25
from vllm.v1.utils import copy_slice
26
from vllm.v1.worker.block_table import MultiGroupBlockTable
27
28
29
30
31
32


@dataclass
class CachedRequestState:
    """State of a single request that is kept across scheduler steps.

    Holds the prompt (as token IDs and/or embeddings), the sampling or
    pooling parameters, the KV-cache block IDs and the tokens generated so
    far. `num_prompt_tokens` is derived in `__post_init__`.
    """

    req_id: str
    # `None` when the prompt was supplied only as `prompt_embeds`.
    prompt_token_ids: Optional[list[int]]
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]
    # Speculative (draft) token IDs, or `None` when there are none.
    spec_token_ids: Optional[list[int]] = None

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None
    prompt_embeds: Optional[torch.Tensor] = None

    # Lazily populated when `VLLM_V1_FAST_TOKEN_ID_COPY` is enabled.
    _prompt_token_ids_np: Optional[np.ndarray] = field(default=None,
                                                       init=False,
                                                       repr=False,
                                                       compare=False)

    def __post_init__(self):
        # The prompt length comes from whichever prompt form is present.
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds)

    @property
    def num_tokens(self) -> int:
        """Number of prompt tokens plus generated output tokens."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        """Multimodal feature data wrapped as `MultiModalKwargsItems`.

        Features whose `data` is `None` are skipped.
        """
        return [
            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
            if f.data is not None
        ]

    def get_token_id(self, idx: int) -> int:
        """Return the token ID at position `idx` of prompt + output.

        Returns -1 when `idx` is past the last output token.

        Raises:
          ValueError: if `idx` falls in a prompt that was provided only
            as embeddings, so its token ID is unknown.
        """
        if idx < self.num_prompt_tokens:
            if self.prompt_token_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown.")
            return self.prompt_token_ids[idx]
        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        else:
            return -1
84
85
86
87
88


class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
    ):
        """Allocate the fixed-capacity, slot-indexed buffers of the batch.

        Every per-request buffer has `max_num_reqs` rows; a request's row
        is the slot index returned by `add_request`.

        Args:
          max_num_reqs: number of request slots in every buffer.
          max_model_len: per-slot capacity of the token-id buffers.
          max_num_batched_tokens: forwarded to the block table.
          device: device that holds the device-side sampling tensors.
          pin_memory: whether the CPU staging tensors are pinned.
          vocab_size: vocabulary size.
          block_sizes: block size of each KV-cache group.
          logitsprocs: logits processors; an empty set is used if falsy.
          is_spec_decode: whether speculative decoding is enabled.
          is_pooling_model: whether the model is a pooling model.
          num_speculative_tokens: forwarded to the block table.
        """
        # Configuration.
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Slot bookkeeping; `None` marks a vacated slot until condense().
        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        slots = (max_num_reqs, )
        token_grid = (max_num_reqs, max_model_len)

        def device_and_cpu(
                dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
            # Uninitialized per-slot device tensor plus its CPU staging twin.
            on_device = torch.empty(slots, dtype=dtype, device=device)
            on_cpu = torch.empty(slots,
                                 dtype=dtype,
                                 device="cpu",
                                 pin_memory=pin_memory)
            return on_device, on_cpu

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # It is never transferred to the GPU directly, so it is not pinned.
        self.token_ids_cpu_tensor = torch.zeros(token_grid,
                                                device="cpu",
                                                dtype=torch.int32,
                                                pin_memory=False)
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Marks which token_ids_cpu positions hold real token IDs; prompt
        # positions given only as embeddings stay False.
        self.is_token_ids = torch.zeros(token_grid,
                                        device="cpu",
                                        dtype=torch.bool,
                                        pin_memory=False)
        # Prompt embeddings kept per slot (slot -> tensor of shape
        # (num_prompt_tokens, hidden_size)) instead of one large upfront
        # allocation, which could OOM when max_model_len is big.
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            slots,
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = (
            self.num_computed_tokens_cpu_tensor.numpy())

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            num_speculative_tokens=num_speculative_tokens,
        )

        # Sampling: device tensor, CPU staging tensor, numpy view of the
        # staging tensor, and the ids of requests using each parameter.
        self.temperature, self.temperature_cpu_tensor = device_and_cpu(
            torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p, self.top_p_cpu_tensor = device_and_cpu(torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k, self.top_k_cpu_tensor = device_and_cpu(torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        (self.frequency_penalties,
         self.frequency_penalties_cpu_tensor) = device_and_cpu(torch.float)
        self.frequency_penalties_cpu = (
            self.frequency_penalties_cpu_tensor.numpy())
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        (self.presence_penalties,
         self.presence_penalties_cpu_tensor) = device_and_cpu(torch.float)
        self.presence_penalties_cpu = (
            self.presence_penalties_cpu_tensor.numpy())
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        (self.repetition_penalties,
         self.repetition_penalties_cpu_tensor) = device_and_cpu(torch.float)
        self.repetition_penalties_cpu = (
            self.repetition_penalties_cpu_tensor.numpy())
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            slots, dtype=torch.int64, device="cpu", pin_memory=pin_memory)
        self.num_accepted_tokens_cpu = (
            self.num_accepted_tokens_cpu_tensor.numpy())

        # lora related
        self.request_lora_mapping = np.zeros(slots, dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        # Both tensors are allocated lazily by add_request.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
                                                          dtype=bool)

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()

        # This is updated each time the batch constituents change. It is
        # built last among the sampling state because it reads that state.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
        self.prev_req_id_to_index: Optional[dict[str, int]] = None

290
    @property
    def req_ids(self) -> list[str]:
        """Request IDs in slot order.

        The backing list may briefly contain `None` for vacated slots while
        the batch is being updated; `cast` only narrows the static type and
        does not check or copy anything at runtime.
        """
        backing = self._req_ids
        return cast(list[str], backing)
295

296
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Pick a slot for `request` and record the addition.

        A slot freed by an earlier removal is reused when one is available;
        otherwise the request goes right after the current requests. The
        `(index, sampling_params, prompt_token_ids, output_token_ids)` tuple
        consumed by logits processors is only queued for requests that
        carry sampling params.

        Returns:
          The slot index the request will occupy.
        """
        builder = self.batch_update_builder

        slot = builder.pop_removed()
        if slot is None:
            # No freed slot to fill: append at the end.
            slot = self.num_reqs
        assert slot < self.max_num_reqs

        builder.batch_changed = True
        params = request.sampling_params
        if params:
            builder.added.append((slot, params, request.prompt_token_ids,
                                  request.output_token_ids))

        return slot
316

317
318
319
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert `request` into the batch and fill its per-slot state.

        The slot is chosen by `_register_add_request`. Prompt/output token
        IDs, any speculative token IDs, token counts, block IDs, sampling or
        pooling parameters and LoRA bookkeeping are written at that slot.

        Args:
          request: cached state of the request to add.

        Returns:
          The slot index assigned to the request.

        Raises:
          NotImplementedError: if the request has neither sampling nor
            pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Read the flag once so the prompt and output copies agree.
        fast_copy = envs.VLLM_V1_FAST_TOKEN_ID_COPY

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        if request.prompt_token_ids is not None:
            if fast_copy:
                # np.copyto from a cached int32 array avoids converting the
                # Python list to an array every time the request is added.
                np.copyto(
                    self.token_ids_cpu[req_index, :num_prompt_tokens],
                    self._prompt_token_ids_as_np(request, num_prompt_tokens),
                    casting='no',
                )
            else:
                self.token_ids_cpu[
                    req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # The slot may be reused; drop embeddings left by its previous
            # occupant so they are not attributed to this request.
            self.req_prompt_embeds.pop(req_index, None)

        if fast_copy:
            output_token_ids_np = np.asarray(request.output_token_ids,
                                             dtype=np.int32)
            end_idx = start_idx + output_token_ids_np.size
            np.copyto(
                self.token_ids_cpu[req_index, start_idx:end_idx],
                output_token_ids_np,
                casting='no',
            )
        else:
            end_idx = start_idx + len(request.output_token_ids)
            self.token_ids_cpu[req_index,
                               start_idx:end_idx] = request.output_token_ids

        # Speculative tokens, if any, are written right after the outputs.
        # NOTE(review): their positions are not marked in is_token_ids —
        # confirm that is intended.
        num_spec_tokens = 0
        if request.spec_token_ids is not None:
            num_spec_tokens = len(request.spec_token_ids)
            self.token_ids_cpu[req_index, end_idx:end_idx +
                               num_spec_tokens] = request.spec_token_ids

        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens + num_spec_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if (self.is_spec_decode
                    and is_spec_decode_unsupported(sampling_params)):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top_k is clamped to "no top-k" (full vocab).
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[
                req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[
                req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[
                req_index] = sampling_params.repetition_penalty
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            # A logprobs value of -1 means "the whole vocabulary".
            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = (self.vocab_size
                                             if sampling_params.logprobs == -1
                                             else sampling_params.logprobs)
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[req_id] = (
                    self.vocab_size if sampling_params.prompt_logprobs == -1
                    else sampling_params.prompt_logprobs)

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device)
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu")
                # Mask everything, then unmask the allowed ids.
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[
                    req_index] = sampling_params.bad_words_token_ids
        elif pooling_params := request.pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids)
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

    @staticmethod
    def _prompt_token_ids_as_np(request: "CachedRequestState",
                                num_prompt_tokens: int) -> np.ndarray:
        """Return `request.prompt_token_ids` as an int32 ndarray.

        Reuses `request._prompt_token_ids_np` when it is an int32 ndarray of
        length `num_prompt_tokens`. Otherwise converts the list and tries
        to cache the result on the request.
        """
        cached = getattr(request, "_prompt_token_ids_np", None)
        if (isinstance(cached, np.ndarray) and cached.dtype == np.int32
                and cached.size == num_prompt_tokens):
            return cached
        prompt_token_ids_np = np.asarray(request.prompt_token_ids,
                                         dtype=np.int32)
        try:
            request._prompt_token_ids_np = prompt_token_ids_np
        except AttributeError:
            # Best effort: objects that reject the attribute (slots, frozen
            # dataclasses, read-only properties) just skip caching.
            pass
        return prompt_token_ids_np

500
    def remove_request(self, req_id: str) -> Optional[int]:
        """Remove a request from the batch, leaving its slot empty.

        This method must always be followed by a call to condense().

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        # Release this slot's prompt embeddings. condense() only moves
        # entries for slots that have one, so a leftover entry here would
        # otherwise be attributed to whichever request later lands in this
        # slot, and it would keep the tensor alive.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this adapter: drop its bookkeeping.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.spec_decode_unsupported_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-slot state between slots `i1` and `i2`.

        Both slots must currently hold a request. For non-pooling models
        the swap is also recorded in the batch update builder so logits
        processors can follow the reordering.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1] # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        # Element swaps on 1-D numpy arrays are safe: indexing yields scalars.
        self.num_tokens[i1], self.num_tokens[i2] =\
            self.num_tokens[i2], self.num_tokens[i1]
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]

        # Row swaps must NOT be written as `a[i1], a[i2] = a[i2], a[i1]`.
        # Indexing a row (numpy or torch) returns a view, so the second
        # assignment would copy back the already-overwritten row and leave
        # both rows equal to the original row i2. Advanced (list) indexing
        # on the right-hand side makes a copy, which makes the swap safe.
        # TODO(lucas): optimize this by only copying valid indices
        self.token_ids_cpu[[i1, i2], ...] = self.token_ids_cpu[[i2, i1], ...]
        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append(
            (i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
            self.temperature_cpu[i2], self.temperature_cpu[i1]
        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
            self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
            self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
            self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # 2-D torch tensor: swap rows via advanced indexing (see above).
            mask = self.allowed_token_ids_mask_cpu_tensor
            mask[[i1, i2], ...] = mask[[i2, i1], ...]
631

632
633
634
635
636
637
638
639
640
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        Empty slots are taken from ``self.batch_update_builder.removed``.
        Each iteration moves the request at the highest occupied index into
        the empty index reported by ``peek_removed()``. It copies the
        request's per-request CPU state, LoRA mapping and block-table row.

        The loop stops when either:
          * the reported empty index is not below the highest occupied
            index, or
          * no empty indices remain.

        For non-pooling models, every move is also appended to
        ``batch_update_builder.moved`` so logits processors can track it.
        Finally, ``_req_ids`` and ``req_output_token_ids`` are truncated to
        ``num_reqs`` entries.

        The batch is modified in place; nothing is returned.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        # NOTE(review): the loop below relies on pop_removed() shrinking the
        # same container that `removed` returned - confirm in
        # BatchUpdateBuilder.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[
                    empty_index] = self.req_prompt_embeds.pop(last_req_index)
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index,
                 MoveDirectionality.UNIDIRECTIONAL))

            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.num_accepted_tokens_cpu[
                empty_index] = self.num_accepted_tokens_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]

    def refresh_metadata(self):
        """Apply any batch updates to sampling metadata.

        Pooling models only reset ``self.batch_update_builder``. If
        ``reset()`` reports a change, ``self.sampling_metadata`` is rebuilt.

        Other models take the update object returned by
        ``batch_update_builder.get_and_reset(self.num_reqs)``. They pass it
        to ``update_state`` on every processor in ``self.logitsprocs.all``,
        then rebuild ``self.sampling_metadata`` only if that update object is
        truthy.
        """
        if self.is_pooling_model:
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` for the first ``num_reqs`` requests.

        Each sampling parameter is copied only when at least one request in
        the batch uses it. The copy goes from its ``*_cpu_tensor`` buffer
        into the matching tensor (e.g. ``self.temperature``) via
        ``copy_slice``.

        ``temperature``, ``top_p``, ``top_k``, ``prompt_token_ids`` and
        ``allowed_token_ids_mask`` are ``None`` when unused. Prompt token ids
        are materialized only when penalties are active or some request's
        logits processing needs token ids.
        """
        num_reqs = self.num_reqs

        # Temperature is not needed when every request is greedy.
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any())
        if needs_prompt_token_ids:
            # The prompt tokens are used only for applying penalties or
            # step pooling during the sampling/pooling process.
            # Hence copy these tensors only when there are requests which
            # need penalties/step_pooler to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

    def get_pooling_params(self) -> list[PoolingParams]:
        """Collect the ``PoolingParams`` of each request, in batch order.

        Asserts that ``self.req_ids`` and ``self.pooling_params`` have the
        same length, then looks each request id up in ``pooling_params``.
        """
        assert len(self.req_ids) == len(self.pooling_params)
        ordered = []
        for rid in self.req_ids:
            ordered.append(self.pooling_params[rid])
        return ordered

    def get_pooling_metadata(self) -> PoolingMetadata:
        """Build the ``PoolingMetadata`` for the requests in the batch.

        ``prompt_token_ids`` is read from the current
        ``self.sampling_metadata``. It can be ``None`` when that metadata was
        built without prompt token ids.

        ``prompt_lens`` comes from ``torch.from_numpy``, which shares memory
        with ``self.num_prompt_tokens`` rather than copying it.
        """
        pooling_params = self.get_pooling_params()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(
                self.num_prompt_tokens[:self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
839
840
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
841
842
843
844
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
845
846
            pin_memory=self.pin_memory,
        )
847
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
848
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
849
850
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
851
        for i in range(num_reqs):
852
853
854
855
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

856
857
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: Tokens scheduled per request. It is used as
                the ``repeats`` argument of ``ndarray.repeat`` over the first
                ``self.num_reqs`` LoRA ids, so it must be a scalar or have
                one entry per request.

        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where
               prompt_lora_mapping[i] is the LoRA id to use for the ith
               prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where token_lora_mapping[i] is the LoRA id to use for the ith
               token.
            3. lora_requests: Set of every LoRARequest currently stored in
               ``self.lora_id_to_lora_request``.
        """
        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        # tolist() produces plain Python ints, matching the annotated return
        # type. It also converts in one C-level pass instead of boxing numpy
        # scalars one element at a time.
        prompt_lora_mapping = tuple(req_lora_mapping.tolist())
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens).tolist())
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    @property
    def num_reqs(self) -> int:
        """Count of requests currently tracked in ``req_id_to_index``."""
        index_by_req_id = self.req_id_to_index
        return len(index_by_req_id)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty."""
        num_random = len(self.random_reqs)
        return not num_random

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty."""
        num_greedy = len(self.greedy_reqs)
        return not num_greedy

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` is empty."""
        return not len(self.top_p_reqs)

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` is empty."""
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when ``presence_penalties_reqs``, ``frequency_penalties_reqs``
        and ``repetition_penalties_reqs`` are all empty."""
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> Optional[int]:
        """Largest value in ``self.num_logprobs``, or None when it is empty."""
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_prompt_logprob(self) -> bool:
        """True when ``self.num_prompt_logprobs`` is empty (falsy)."""
        return not self.num_prompt_logprobs

    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` is empty."""
        num_with_allowed = len(self.has_allowed_token_ids)
        return not num_with_allowed