gpu_input_batch.py 39.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
26
from vllm.v1.utils import copy_slice
27
from vllm.v1.worker.block_table import MultiGroupBlockTable
28
29
30
31
32


@dataclass
class CachedRequestState:
    """Worker-side snapshot of one request.

    Holds the prompt (as token ids and/or embeddings), the tokens generated
    so far, the block ids handed to the block table, and the request's
    sampling or pooling parameters.
    """

    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    def __post_init__(self):
        # Either prompt_token_ids or prompt_embeds may be None; the helper
        # derives the prompt length from what was provided.
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )

    @property
    def num_tokens(self) -> int:
        """Prompt length plus the number of generated output tokens."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position ``idx`` of prompt + output.

        Returns -1 when ``idx`` lies past the last generated token.

        Raises:
            ValueError: if ``idx`` falls inside a prompt that was supplied
                only as embeddings, so its token id is unknown.
        """
        prompt_len = self.num_prompt_tokens
        if idx >= prompt_len:
            output_pos = idx - prompt_len
            if output_pos < len(self.output_token_ids):
                return self.output_token_ids[output_pos]
            return -1
        prompt_ids = self.prompt_token_ids
        if prompt_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return prompt_ids[idx]


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
    ):
        """Allocate the persistent per-request buffers of the batch.

        Per-request arrays are sized for ``max_num_reqs`` rows and indexed by
        the row a request occupies. Sampling parameters that live on
        ``device`` also get a CPU twin (pinned when ``pin_memory`` is set) and
        a NumPy view of that twin for cheap host-side writes.
        """
        # Configuration.
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Row bookkeeping.
        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        def device_and_host(dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
            # One uninitialized (max_num_reqs,) tensor on `device` and its
            # CPU counterpart; rows are filled in when requests are added.
            on_device = torch.empty((max_num_reqs,), dtype=dtype, device=device)
            on_host = torch.empty(
                (max_num_reqs,), dtype=dtype, device="cpu", pin_memory=pin_memory
            )
            return on_device, on_host

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        token_grid = (max_num_reqs, max_model_len)
        self.token_ids_cpu_tensor = torch.zeros(
            token_grid, device="cpu", dtype=torch.int32, pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Marks positions whose token id is known (False for prompt positions
        # supplied only as embeddings).
        self.is_token_ids_tensor = torch.zeros(
            token_grid, device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()

        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
        )

        # Sampling-related.
        self.temperature, self.temperature_cpu_tensor = device_and_host(torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p, self.top_p_cpu_tensor = device_and_host(torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k, self.top_k_cpu_tensor = device_and_host(torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        (
            self.frequency_penalties,
            self.frequency_penalties_cpu_tensor,
        ) = device_and_host(torch.float)
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        (
            self.presence_penalties,
            self.presence_penalties_cpu_tensor,
        ) = device_and_host(torch.float)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        (
            self.repetition_penalties,
            self.repetition_penalties_cpu_tensor,
        ) = device_and_host(torch.float)
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        # Both tensors are allocated lazily by add_request.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int] | None] = []

        # This is updated each time the batch constituents change.
        # Everything assigned above is already in place when this runs.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.cuda.Event | None = None
    @property
    def req_ids(self) -> list[str]:
        """Request ids in row order of the persistent batch.

        ``_req_ids`` may briefly contain ``None`` placeholders while removals
        and additions are being applied. Outside of those updates every entry
        is a real id, which is what the cast expresses.
        """
        row_ids = self._req_ids
        return cast(list[str], row_ids)
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Choose the row for a new request and record the addition.

        A row freed by an earlier removal is reused when one is pending;
        otherwise the request is appended at ``self.num_reqs``. Detailed
        added-request metadata (consumed by logits processors) is recorded
        only for requests that carry sampling params.

        Returns:
            The row index assigned to the request.
        """
        builder = self.batch_update_builder
        reused_index = builder.pop_removed()
        new_req_index = self.num_reqs if reused_index is None else reused_index
        assert new_req_index < self.max_num_reqs

        builder.batch_changed = True
        if not request.sampling_params:
            # Pooling requests: no logitsprocs bookkeeping needed.
            return new_req_index

        builder.added.append(
            (
                new_req_index,
                request.sampling_params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
        )
        return new_req_index
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Place ``request`` into a row of the persistent batch.

        The row is either one freed by an earlier ``remove_request`` or a new
        row appended at the end (chosen by ``_register_add_request``).
        Row-keyed state that a previous occupant of the row may have left
        behind (prompt embeddings, the allowed-token-ids mask row) is cleared
        when the new request does not overwrite it.

        Returns:
            The row index assigned to the request.

        Raises:
            NotImplementedError: if the request carries neither sampling nor
                pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index] = []

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            # Prompt supplied only as embeddings: its token ids are unknown.
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # req_prompt_embeds is keyed by row, so a reused row may still
            # hold the embeddings of the request that occupied it before.
            self.req_prompt_embeds.pop(req_index, None)
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Total tokens known for the request (prompt + output).
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if self.is_spec_decode and is_spec_decode_unsupported(sampling_params):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top_k means "no top-k restriction".
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.prompt_logprobs == -1
                    else sampling_params.prompt_logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False
            elif self.allowed_token_ids_mask_cpu_tensor is not None:
                # An unrestricted request needs an all-False ("no masking")
                # row; clear whatever a previous occupant left in this row.
                self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

    def remove_request(self, req_id: str) -> int | None:
        """This method must always be followed by a call to condense().

        Frees the request's row: id-keyed state is dropped and row-keyed
        state is cleared so that a later request reusing the row (via
        ``add_request``) or moved into it (via ``condense``) does not
        inherit it.

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index] = None
        # Prompt embeddings are keyed by row, not request id; drop them here
        # (for pooling and non-pooling models alike) so they don't leak to
        # the row's next occupant.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.spec_decode_unsupported_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-row state between rows ``i1`` and ``i2``.

        Both rows must hold live requests. For non-pooling models the swap is
        also recorded in ``batch_update_builder.moved`` so logits processors
        can follow the reordering.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        assert old_id_i1 is not None and old_id_i2 is not None

        self._req_ids[i1], self._req_ids[i2] = old_id_i2, old_id_i1
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )

        # NOTE: arr[i] on a 2-D array/tensor is a *view*, so a tuple swap
        # `a[i1], a[i2] = a[i2], a[i1]` would copy row i2 over both rows.
        # Advanced indexing on the right-hand side (a[[i2, i1]]) materializes
        # a copy first, making the in-place swap safe for 1-D and 2-D arrays.
        # TODO(lucas): optimize token_ids_cpu by only copying valid indices.
        rows = [i1, i2]
        swapped_rows = [i2, i1]
        for per_row in (
            self.num_tokens,
            self.num_tokens_no_spec,
            self.num_prompt_tokens,
            self.num_computed_tokens_cpu,
            self.token_ids_cpu,
            self.is_token_ids,
            self.request_lora_mapping,
            # Set per row for pooling requests, so it must travel with them.
            self.logits_processing_needs_token_ids,
        ):
            per_row[rows] = per_row[swapped_rows]

        # Swap prompt embeddings if they exist.
        embeds_i1 = self.req_prompt_embeds.pop(i1, None)
        embeds_i2 = self.req_prompt_embeds.pop(i2, None)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2

        self.block_table.swap_row(i1, i2)

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        for per_row in (
            self.temperature_cpu,
            self.top_p_cpu,
            self.top_k_cpu,
            self.frequency_penalties_cpu,
            self.presence_penalties_cpu,
            self.repetition_penalties_cpu,
            self.num_accepted_tokens_cpu,
        ):
            per_row[rows] = per_row[swapped_rows]

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # Rows of this 2-D tensor are views; see the NOTE above.
            mask = self.allowed_token_ids_mask_cpu_tensor
            mask[rows] = mask[swapped_rows]
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Rows freed by ``remove_request`` (pending in
        ``batch_update_builder.removed``) are filled by moving the
        highest-indexed live requests into them, so that live requests end
        up in rows ``[0, num_reqs)``. Any consecutive empty indices at the
        very end are not filled; the per-request lists are trimmed instead.

        Returns nothing. For non-pooling models each move is recorded in
        ``batch_update_builder.moved``.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            # No live rows remain, so no row-keyed embeddings should either.
            self.req_prompt_embeds.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            spec_token_ids = self.spec_token_ids[last_req_index]
            self.spec_token_ids[empty_index] = spec_token_ids
            self.spec_token_ids[last_req_index] = None

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
            else:
                # The moved request has no embeddings; drop any left in the
                # target row by the request that was removed from it.
                self.req_prompt_embeds.pop(empty_index, None)
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]
            # Set per row for pooling requests, so it must follow the move.
            self.logits_processing_needs_token_ids[empty_index] = (
                self.logits_processing_needs_token_ids[last_req_index]
            )

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
                # The vacated row can be reused by a later request without
                # allowed_token_ids; reset it to "no masking" (all False).
                self.allowed_token_ids_mask_cpu_tensor[last_req_index].fill_(False)

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]
        # Rows at or beyond num_reqs are now empty; drop any embeddings that
        # removed requests left there.
        for stale_index in [i for i in self.req_prompt_embeds if i >= num_reqs]:
            del self.req_prompt_embeds[stale_index]
728

729
    def refresh_metadata(self):
        """Apply any pending batch updates to the sampling metadata.

        Pooling models: batch-update tracking is reset via
        ``batch_update_builder.reset()``, and the sampling metadata is rebuilt
        only if that call reports a change. The logits processors are not
        touched on this path.

        Other models: the accumulated update is built (and tracking reset)
        with ``batch_update_builder.get_and_reset(self.num_reqs)``, passed to
        every logits processor's ``update_state``, and the sampling metadata
        is rebuilt when the update is truthy.
        """
        if self.is_pooling_model:
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build the ``SamplingMetadata`` for the first ``num_reqs`` requests.

        Per-request sampling parameters staged in the CPU tensors are copied
        to their counterparts with ``copy_slice`` only when at least one
        request in the batch uses them. Optional fields are ``None`` (or
        ``[]`` for ``output_token_ids``) when no request needs them:
        temperature, top-p/top-k, prompt token ids, output token ids, and the
        allowed-token mask.

        Note: the penalty tensors are always sliced into the result, but they
        are only refreshed from the CPU side when ``no_penalties`` is False.
        Consumers presumably check ``no_penalties`` before reading them.
        """
        num_reqs = self.num_reqs

        # Temperature only matters when some request samples randomly.
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs
            )
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(
                self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
            )
            copy_slice(
                self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
            )
            copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
            )

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not self.no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=output_token_ids,
            spec_token_ids=cast(list[list[int]], self.spec_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of each request, ordered like ``req_ids``.

        ``pooling_params`` is looked up by request id and must hold exactly
        one entry per request in the batch.
        """
        assert len(self.req_ids) == len(self.pooling_params)
        lookup = self.pooling_params.__getitem__
        return list(map(lookup, self.req_ids))

    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the ``PoolingMetadata`` for the current batch.

        ``prompt_lens`` wraps the first ``num_reqs`` entries of
        ``num_prompt_tokens`` via ``torch.from_numpy``, which shares memory
        with the numpy array. ``prompt_token_ids`` is reused from the current
        ``sampling_metadata``. It may be ``None`` when the last metadata
        build did not need prompt token ids.
        """
        pooling_params = self.get_pooling_params()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Return the batch's prompt token ids as a right-padded int64 tensor.

        Row ``i`` holds the first ``num_prompt_tokens[i]`` entries of
        ``token_ids_cpu[i]``. Positions past each prompt are filled with
        ``vocab_size``, a value no real token id takes. The tensor is built on
        the CPU (pinned when ``self.pin_memory`` is set) and then moved to
        ``self.device`` with ``non_blocking=True``.
        """
        num_reqs = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:num_reqs]
        max_prompt_len = int(prompt_lens.max())
        prompt_token_ids_cpu_tensor = torch.empty(
            (num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value. One broadcast mask marks every position at
        # or beyond its row's prompt length, replacing a per-row Python loop.
        pad_mask = np.arange(max_prompt_len) >= prompt_lens[:, None]
        prompt_token_ids[pad_mask] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)

    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: Per-request token counts for the first
                ``self.num_reqs`` requests, in batch order. They are used as
                the repeat counts for ``np.ndarray.repeat``.

        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where token_lora_mapping[i] is the LoRA id to use for the ith
               token.
            3. lora_requests: The set of all values currently stored in
               ``lora_id_to_lora_request``.
        """
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        # Each request's LoRA id is repeated once per scheduled token,
        # presumably matching the flattened token order of the batch.
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
        async_copy_ready_event: torch.cuda.Event,
    ) -> None:
        """
        Record (or clear) the async-scheduling handles for sampled token ids.

        The CPU tensor of sampled ids and its copy-ready event are kept only
        when the current sampling metadata carries ``output_token_ids``. Those
        are the lists that may need placeholder repair before logits
        processors run. Otherwise both handles are reset to ``None``.
        """
        keep_handles = bool(self.sampling_metadata.output_token_ids)
        self.sampled_token_ids_cpu = sampled_token_ids_cpu if keep_handles else None
        self.async_copy_ready_event = (
            async_copy_ready_event if keep_handles else None
        )

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        A request's last output id is replaced only when both hold:
        - the request appears in ``prev_req_id_to_index``;
        - its output list ends with the ``-1`` placeholder.

        The copy-ready event is synchronized lazily, at most once per call,
        and only when a placeholder actually needs replacing.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                # Request was not in the previous step's index map.
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                # Wait for the copy and materialize the ids only on first need.
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
            # Replace placeholder token id with actual sampled id.
            req_output_token_ids[-1] = sampled_token_ids[prev_index]

    @property
    def num_reqs(self) -> int:
        """Count of entries in ``req_id_to_index``, i.e. requests in the batch."""
        index_by_req_id = self.req_id_to_index
        return len(index_by_req_id)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` has no entries."""
        return not len(self.random_reqs)

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` has no entries."""
        return not len(self.greedy_reqs)

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` has no entries."""
        return not len(self.top_p_reqs)

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` has no entries."""
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when the presence, frequency and repetition penalty request
        collections are all empty."""
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )

    @property
959
    def max_num_logprobs(self) -> int | None:
960
        return max(self.num_logprobs.values()) if self.num_logprobs else None
961
962
963

    @property
    def no_prompt_logprob(self) -> bool:
        """True when ``num_prompt_logprobs`` is empty (falsy)."""
        return not self.num_prompt_logprobs

    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` has no entries."""
        return not len(self.has_allowed_token_ids)