gpu_input_batch.py 48 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4

5
from dataclasses import dataclass, field
6
from typing import Optional, cast
7
8
9
10

import numpy as np
import torch

11
12
from vllm import envs
from vllm.config import get_current_vllm_config
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams, SamplingType
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
from vllm.utils.collection_utils import swap_dict_values
19
from vllm.v1.outputs import LogprobsTensors
20
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
21
22
23
24
25
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
26
from vllm.v1.sample.metadata import SamplingMetadata
27
from vllm.v1.utils import copy_slice
28
from vllm.v1.worker.block_table import MultiGroupBlockTable
29
30
31
32
33


@dataclass
class CachedRequestState:
    """Per-request state kept alongside the persistent input batch.

    `num_prompt_tokens` is not a dataclass field; it is derived in
    `__post_init__` from `prompt_token_ids` / `prompt_embeds`.
    """

    req_id: str
    # None when the prompt was supplied via `prompt_embeds` instead.
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    # Falsy for requests that carry `pooling_params` instead.
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Cached int32 numpy copy of `prompt_token_ids`; filled lazily by
    # InputBatch.add_request on the fast token-id copy path. Excluded from
    # __init__, repr and comparisons.
    _prompt_token_ids_np: np.ndarray | None = field(
        default=None, init=False, repr=False, compare=False
    )

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    # for multi layer eagle proposer
    cached_len: torch.Tensor | None = None
    cached_token_ids: torch.Tensor | None = None
    cached_hidden_states: torch.Tensor | None = None
    cached_slot_mappings: torch.Tensor | None = None
    cached_positions: torch.Tensor | None = None

    def __post_init__(self):
        """Derive the prompt length and set up pooling state if needed."""
        prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.num_prompt_tokens = prompt_len

        # Pooling requests always start from a fresh PoolingStates object,
        # replacing whatever was passed in for `pooling_states`.
        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Total number of tokens: prompt tokens plus generated tokens."""
        num_generated = len(self.output_token_ids)
        return self.num_prompt_tokens + num_generated

    def get_token_id(self, idx: int) -> int:
        """Return the token ID at position `idx` of prompt + output.

        Args:
            idx: Position in the concatenated prompt/output sequence.

        Returns:
            The token ID, or -1 when `idx` lies beyond the last output token.

        Raises:
            ValueError: If `idx` falls inside the prompt but the prompt was
                given as embeddings, so no token ID exists for it.
        """
        prompt_len = self.num_prompt_tokens
        if idx >= prompt_len:
            # Position belongs to the generated part (or is out of range).
            output_pos = idx - prompt_len
            if output_pos < len(self.output_token_ids):
                return self.output_token_ids[output_pos]
            return -1

        if self.prompt_token_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return self.prompt_token_ids[idx]
94
95
96
97


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        cp_kv_cache_interleave_size: int = 1,
        multi_layer_eagle_num: int = 0,
        hidden_size: int | None = None,
    ):
        """Allocate the persistent per-request buffers of the batch.

        Args:
            max_num_reqs: Number of request slots. When `is_spec_decode` is
                set and `envs.VLLM_REJECT_SAMPLE_OPT` is enabled, the value
                is multiplied by `1 + num_speculative_tokens` before most
                buffers are sized.
            max_model_len: Second dimension of the per-request token buffers.
            max_num_batched_tokens: Forwarded to the block table.
            device: Device for the on-device sampling tensors.
            pin_memory: Whether the CPU staging tensors are pinned.
            vocab_size: Vocabulary size; also used as the "no top-k" value.
            block_sizes: The block size of each KV cache group.
            kernel_block_sizes: Forwarded to the block table.
            max_num_blocks_per_req: Forwarded to the block table.
            logitsprocs: Logits processors; an empty `LogitsProcessors()` is
                created when none are given.
            logitsprocs_need_output_token_ids: Stored as-is on the instance.
            is_spec_decode: Whether speculative decoding is enabled.
            is_pooling_model: Whether the batch serves a pooling model.
            cp_kv_cache_interleave_size: Forwarded to the block table.
            multi_layer_eagle_num: Width of the multi-layer eagle caches;
                the caches are only allocated when this is positive.
            hidden_size: Last dimension of `cached_hidden_states`. Required
                when `multi_layer_eagle_num > 0`.

        Raises:
            ValueError: If `multi_layer_eagle_num > 0` and `hidden_size` is
                not provided.
        """
        # Fail fast with a clear message instead of an opaque error from
        # torch.zeros() after most buffers have already been allocated.
        if multi_layer_eagle_num > 0 and hidden_size is None:
            raise ValueError(
                "hidden_size must be provided when multi_layer_eagle_num > 0"
            )

        ori_max_num_reqs = max_num_reqs
        if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
            vllm_config = get_current_vllm_config()
            max_num_reqs = max_num_reqs * (
                1 + vllm_config.speculative_config.num_speculative_tokens
            )

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        # NOTE(review): this buffer is sized with the *unexpanded* request
        # count while every other per-request buffer below uses the expanded
        # `max_num_reqs`. Row indices >= ori_max_num_reqs would be out of
        # bounds here -- confirm that such indices are never used for it.
        self.token_ids_cpu_tensor = torch.zeros(
            (ori_max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related.
        # Each parameter has a device tensor, a CPU tensor, and a numpy view
        # of the CPU tensor (the `*_cpu` attribute) used for per-row writes.
        self.temperature = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device=device
        )
        self.temperature_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
        self.top_p_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
        self.top_k_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.presence_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # Multi layer eagle
        # The cached_* tensors only exist when multi_layer_eagle_num > 0, so
        # every user must check `self.multi_layer_eagle_num` first.
        self.multi_layer_eagle_num = multi_layer_eagle_num
        if multi_layer_eagle_num > 0:
            self.cached_len = torch.zeros(
                (max_num_reqs,), dtype=torch.int64, device=device
            )
            self.cached_token_ids = torch.zeros(
                (
                    max_num_reqs,
                    multi_layer_eagle_num,
                ),
                dtype=torch.int32,
                device=device,
            )
            self.cached_hidden_states = torch.zeros(
                (
                    max_num_reqs,
                    multi_layer_eagle_num,
                    hidden_size,
                ),
                dtype=torch.float,
                device=device,
            )
            self.cached_slot_mappings = torch.zeros(
                (
                    max_num_reqs,
                    multi_layer_eagle_num,
                ),
                dtype=torch.int64,
                device=device,
            )
            self.cached_positions = torch.zeros(
                (
                    max_num_reqs,
                    multi_layer_eagle_num,
                ),
                dtype=torch.int64,
                device=device,
            )

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None
331

332
    @property
    def req_ids(self) -> list[str]:
        """Request IDs by batch index, typed as if no slot were empty."""
        # None elements should only be present transiently
        # while performing state updates to the batch.
        # The cast affects static typing only; the list itself is returned.
        ids = self._req_ids
        return cast(list[str], ids)
337

338
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Pick the batch index for a new request and record the addition.

        Detailed add metadata is recorded only for requests that carry
        sampling params (i.e. not for pooling requests).

        Args:
            request: The request about to be added.

        Returns:
            The batch index assigned to the request.
        """
        builder = self.batch_update_builder

        # Reuse an index freed by an earlier removal when one is available;
        # otherwise the request goes right after the current last one.
        new_req_index = builder.pop_removed()
        if new_req_index is None:
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
        builder.batch_changed = True

        params = request.sampling_params
        if params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            added_entry = (
                new_req_index,
                params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
            builder.added.append(added_entry)

        return new_req_index
363

364
365
366
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert `request` into the batch and fill its per-request state.

        The request is placed at the index chosen by
        `_register_add_request` (a previously freed index, or the end of the
        batch). Token IDs, sampling or pooling parameters, multi-layer eagle
        caches and the LoRA mapping for that index are then populated.

        Args:
            request: The request to add. It must carry either
                `sampling_params` or `pooling_params`.

        Returns:
            The batch index at which the request was stored.

        Raises:
            NotImplementedError: If the request has neither sampling nor
                pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        if request.prompt_token_ids is not None:
            if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
                self.token_ids_cpu[req_index, :num_prompt_tokens] = (
                    request.prompt_token_ids
                )
            else:
                # Fast path: reuse the int32 numpy copy cached on the request
                # and rebuild it only if it is missing or no longer matches.
                prompt_token_ids_np = request._prompt_token_ids_np
                rebuild_prompt_cache = True
                if prompt_token_ids_np is not None:
                    rebuild_prompt_cache = (
                        prompt_token_ids_np.dtype != np.int32
                        or prompt_token_ids_np.size != num_prompt_tokens
                    )
                if rebuild_prompt_cache:
                    prompt_token_ids_np = np.asarray(
                        request.prompt_token_ids, dtype=np.int32
                    )
                    request._prompt_token_ids_np = prompt_token_ids_np
                np.copyto(
                    self.token_ids_cpu[req_index, :num_prompt_tokens],
                    prompt_token_ids_np,
                    casting="no",
                )
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # The index may be a reused slot: drop embeddings left behind by
            # its previous occupant so they are not attributed to this
            # request (and can be freed).
            self.req_prompt_embeds.pop(req_index, None)
        if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
            end_idx = start_idx + len(request.output_token_ids)
            self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        else:
            output_token_ids_np = np.asarray(request.output_token_ids, dtype=np.int32)
            end_idx = start_idx + output_token_ids_np.size
            np.copyto(
                self.token_ids_cpu[req_index, start_idx:end_idx],
                output_token_ids_np,
                casting="no",
            )
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range values are stored as vocab_size and the
                # request is not tracked in top_k_reqs.
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # A requested value of -1 is stored as vocab_size.
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                # Mask everything first, then unmask the allowed tokens.
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        if self.multi_layer_eagle_num > 0:
            self.cached_len[req_index] = request.cached_len
            self.cached_token_ids[req_index] = request.cached_token_ids
            self.cached_hidden_states[req_index] = request.cached_hidden_states
            self.cached_slot_mappings[req_index] = request.cached_slot_mappings
            self.cached_positions[req_index] = request.cached_positions

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        """Replace the stored draft tokens of `request` with this step's.

        Also records the number of scheduled draft tokens on
        `request.prev_num_draft_len`.

        Args:
            request: Request whose draft tokens are refreshed; it must
                already be present in the batch.
            scheduled_spec_tokens: Maps request ID to the draft tokens
                scheduled for it. A missing entry means no draft tokens.
        """
        req_id = request.req_id
        row = self.req_id_to_index[req_id]
        drafts = self.spec_token_ids[row]

        # When speculative decoding is used with structured output,
        # the scheduler can drop draft tokens that do not
        # conform to the schema. This can result in
        # scheduler_output.scheduled_spec_decode_tokens being empty,
        # even when speculative decoding is enabled.
        drafts.clear()

        scheduled = scheduled_spec_tokens.get(req_id, ())
        num_drafts = len(scheduled)
        request.prev_num_draft_len = num_drafts

        if scheduled:
            # For async scheduling, token_ids_cpu assigned from
            # spec_token_ids are placeholders and will be overwritten in
            # _prepare_input_ids.
            first = self.num_tokens_no_spec[row]
            last = first + num_drafts
            self.token_ids_cpu[row, first:last] = scheduled
            drafts.extend(scheduled)

562
    def remove_request(self, req_id: str) -> int | None:
        """This method must always be followed by a call to condense().

        Clears the batch slot of `req_id` and all per-request bookkeeping
        keyed by it. For pooling models only the pooling state is dropped.

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        # Record the freed index and blank out the slot.
        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            users = self.lora_id_to_request_ids[lora_id]
            users.discard(req_id)
            if not users:
                # Last request using this adapter: forget the adapter too.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            for pooling_map in (self.pooling_params, self.pooling_states):
                pooling_map.pop(req_id, None)
            return req_index

        # Sampling bookkeeping keyed by request ID.
        for tracked_ids in (
            self.greedy_reqs,
            self.random_reqs,
            self.top_p_reqs,
            self.top_k_reqs,
            self.frequency_penalties_reqs,
            self.presence_penalties_reqs,
            self.repetition_penalties_reqs,
        ):
            tracked_ids.discard(req_id)

        self.generators.pop(req_index, None)
        for by_req_id in (self.num_logprobs, self.in_progress_prompt_logprobs_cpu):
            by_req_id.pop(req_id, None)

        prev_index_map = self.prev_req_id_to_index
        if prev_index_map is not None:
            prev_index_map.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        mask_cpu = self.allowed_token_ids_mask_cpu_tensor
        if mask_cpu is not None:
            # False means we don't fill with -inf.
            mask_cpu[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)

        return req_index

616
617
618
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state stored at batch indices i1 and i2.

        Both indices must currently hold a request. For pooling models only
        the request identity, token buffers, prompt embeddings, block table
        rows and LoRA mapping are swapped; sampling state is skipped.

        Args:
            i1: First batch index.
            i2: Second batch index.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        # Checked before touching any state so that a failure cannot leave
        # the batch half-swapped.
        assert old_id_i1 is not None and old_id_i2 is not None

        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        # Scalar reads from 1-D numpy arrays yield copies, so tuple swaps are
        # safe for the numpy-backed per-request scalars below.
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # Indexing a torch tensor with an int returns a *view*, so a
            # tuple swap would read row i1 after it was already overwritten
            # and leave both rows equal to the old row i2. Indexing with a
            # list materializes a copy of both rows before assignment.
            self.allowed_token_ids_mask_cpu_tensor[[i1, i2], ...] = (
                self.allowed_token_ids_mask_cpu_tensor[[i2, i1], ...]
            )

        if self.multi_layer_eagle_num > 0:
            # Same view-aliasing hazard as above: all cached_* tensors are
            # torch tensors, so swap every one of them via index lists.
            self.cached_len[[i1, i2]] = self.cached_len[[i2, i1]]
            self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[[i2, i1], ...]
            self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
                [i2, i1], ...
            ]
            self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
                [i2, i1], ...
            ]
            self.cached_positions[[i1, i2], ...] = self.cached_positions[[i2, i1], ...]

733
734
735
736
737
738
739
740
741
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        The batch state is updated in place and nothing is returned. For
        non-pooling models each move is recorded in
        ``self.batch_update_builder.moved`` as
        ``(from_index, to_index, MoveDirectionality.UNIDIRECTIONAL)``.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                # Every remaining hole lies above the last live request.
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # Read the token count before the spec-token lists are swapped:
            # it must include the moved request's own spec tokens.
            num_tokens = self.num_tokens_no_spec[last_req_index] + len(
                self.spec_token_ids[last_req_index]
            )

            # Swap the two list objects, then empty the one left behind at
            # the vacated index.
            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()

            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            if self.multi_layer_eagle_num > 0:
                self.cached_len[empty_index] = self.cached_len[last_req_index]
                self.cached_token_ids[empty_index] = self.cached_token_ids[
                    last_req_index
                ]
                self.cached_hidden_states[empty_index] = self.cached_hidden_states[
                    last_req_index
                ]
                self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
                    last_req_index
                ]
                self.cached_positions[empty_index] = self.cached_positions[
                    last_req_index
                ]

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]

    def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
        """Apply any batch updates to sampling metadata.

        Args:
            repeat_counts: Optional tensor passed through to
                ``_make_sampling_metadata``. For non-pooling models a
                non-``None`` value forces the metadata to be rebuilt even
                when the batch itself did not change.
        """
        if self.is_pooling_model:
            # Pooling models have no logitsprocs to notify; rebuild only
            # when the builder reports a change.
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        if batch_update or repeat_counts is not None:
            self.sampling_metadata = self._make_sampling_metadata(repeat_counts)

    def _make_sampling_metadata(
        self, repeat_counts: Optional[torch.Tensor] = None
    ) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` from the current batch state.

        Tensors are copied with ``copy_slice`` only for the features at
        least one request uses; the corresponding fields are ``None``
        otherwise.

        Args:
            repeat_counts: Optional tensor forwarded unchanged to every
                ``copy_slice`` call.
                NOTE(review): its exact semantics live in ``copy_slice`` --
                confirm there.

        Returns:
            The freshly assembled ``SamplingMetadata``.
        """
        num_reqs = self.num_reqs

        # Host-side summaries for reduced top-k/top-p sampling.
        # Compute before copy_slice(top_k), which may rewrite top_k_cpu_tensor
        # when repeat_counts is provided.
        max_top_k = None
        has_any_no_top_k = False
        if not self.no_top_k and num_reqs > 0:
            top_k_cpu = self.top_k_cpu[:num_reqs]
            max_top_k = int(top_k_cpu.max())
            has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())

        temperature = None
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs, repeat_counts
            )

        top_p = None
        if not self.no_top_p:
            top_p = copy_slice(
                self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts
            )

        top_k = None
        if not self.no_top_k:
            top_k = copy_slice(
                self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts
            )

        no_penalties = self.no_penalties
        frequency_penalties = None
        presence_penalties = None
        repetition_penalties = None
        if not no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            frequency_penalties = copy_slice(
                self.frequency_penalties_cpu_tensor,
                self.frequency_penalties,
                num_reqs,
                repeat_counts,
            )
            presence_penalties = copy_slice(
                self.presence_penalties_cpu_tensor,
                self.presence_penalties,
                num_reqs,
                repeat_counts,
            )
            repetition_penalties = copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
                repeat_counts,
            )

        needs_prompt_token_ids = (
            not no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            allowed_token_ids_mask = copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
                repeat_counts,
            )

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=top_p,
            top_k=top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=frequency_penalties,
            presence_penalties=presence_penalties,
            repetition_penalties=repetition_penalties,
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
            max_top_k=max_top_k,
            has_any_no_top_k=has_any_no_top_k,
        )

    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of every request, ordered like ``req_ids``."""
        req_ids = self.req_ids
        # Every request in the batch must have exactly one params entry.
        assert len(req_ids) == len(self.pooling_params)
        return [self.pooling_params[req_id] for req_id in req_ids]

    def get_pooling_states(self) -> list[PoolingStates]:
        """Return the pooling state of every request, ordered like ``req_ids``."""
        req_ids = self.req_ids
        # Every request in the batch must have exactly one state entry.
        assert len(req_ids) == len(self.pooling_states)
        return [self.pooling_states[req_id] for req_id in req_ids]

    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the ``PoolingMetadata`` for the current batch.

        ``prompt_lens`` is built with ``torch.from_numpy`` from the first
        ``num_reqs`` entries of ``num_prompt_tokens``, and
        ``prompt_token_ids`` is taken from the current ``sampling_metadata``.
        """
        pooling_params = self.get_pooling_params()
        pooling_states = self.get_pooling_states()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
            pooling_states=pooling_states,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
1009
1010
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
1011
1012
1013
1014
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
1015
1016
            pin_memory=self.pin_memory,
        )
1017
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
1018
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
1019
1020
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
1021
        for i in range(num_reqs):
1022
1023
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
1024

1025
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.

        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where token_lora_mapping[i] is the LoRA id to use for the ith
               token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        # Expand each request's LoRA id once per sampled / scheduled token.
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
1053
        async_copy_ready_event: torch.Event,
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        ``-1`` entries act as placeholders: trailing ``-1`` values in a
        request's output list are overwritten in place with the ids sampled
        for that request in the previous step.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        # Converted lazily so the event is only waited on when some request
        # actually needs repairing.
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.tolist()
            # Replace placeholder token id(s) with actual sampled id(s).
            new_ids: list[int] = sampled_token_ids[prev_index]
            if not new_ids:
                continue
            num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
            # Also account for case where there may be a smaller number of
            # output placeholders (tokens can be discarded after a kv-load failure).
            first_placeholder = req_output_token_ids.index(-1)
            num_placeholders = len(req_output_token_ids) - first_placeholder
            num_to_replace = min(num_sampled_ids, num_placeholders)
            del new_ids[num_to_replace:]
            end_index = first_placeholder + num_to_replace
            req_output_token_ids[first_placeholder:end_index] = new_ids

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        In async scheduling case, update spec_token_ids in sampling metadata with
        real draft token ids from prior step. This is called right before they are
        needed by the rejection sampler for penalty/bad_words computation.

        Args:
            draft_token_ids: Draft token ids indexed by each request's
                position in the previous step. Lists that get used are
                truncated in place to the length of the request's current
                spec-token list.
        """
        if not draft_token_ids or not self.prev_req_id_to_index:
            return

        spec_token_ids = self.sampling_metadata.spec_token_ids
        if spec_token_ids is None:
            return

        for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
            if not spec_ids:
                continue
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            draft_ids = draft_token_ids[prev_index]
            if not draft_ids:
                continue
            # Never hand back more draft ids than there were placeholders.
            del draft_ids[len(spec_ids) :]
            spec_ids.clear()
            spec_ids.extend(draft_ids)

    @property
    def num_reqs(self) -> int:
        """Number of entries currently held in ``req_id_to_index``."""
        index_by_req_id = self.req_id_to_index
        return len(index_by_req_id)

    @property
    def all_greedy(self) -> bool:
        """Whether ``random_reqs`` is empty."""
        num_random = len(self.random_reqs)
        return num_random == 0

    @property
    def all_random(self) -> bool:
        """Whether ``greedy_reqs`` is empty."""
        num_greedy = len(self.greedy_reqs)
        return num_greedy == 0

    @property
    def no_top_p(self) -> bool:
        """Whether ``top_p_reqs`` is empty."""
        num_top_p = len(self.top_p_reqs)
        return num_top_p == 0

    @property
    def no_top_k(self) -> bool:
        """Whether ``top_k_reqs`` is empty."""
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """Whether all three penalty request collections are empty.

        True only when ``presence_penalties_reqs``,
        ``frequency_penalties_reqs`` and ``repetition_penalties_reqs`` are
        all empty.
        """
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )

    @property
1156
    def max_num_logprobs(self) -> int | None:
1157
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1158

1159
1160
1161
    @property
    def no_allowed_token_ids(self) -> bool:
        """Whether ``has_allowed_token_ids`` is empty."""
        num_with_allowed_ids = len(self.has_allowed_token_ids)
        return num_with_allowed_ids == 0