"tests/vscode:/vscode.git/clone" did not exist on "10eb24cc91315481414fba0e0134209e6a9d7c94"
gpu_input_batch.py 44.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.utils import copy_slice
26
from vllm.v1.worker.block_table import MultiGroupBlockTable
27
28
29
30
31


@dataclass
class CachedRequestState:
    """State cached for a single request.

    The positional fields identify the request and carry its prompt, sampling
    configuration, block ids and generated tokens; the defaulted fields hold
    optional per-feature state (rope positions, LoRA, prompt embeddings,
    speculative decoding and pooling).

    ``num_prompt_tokens`` is not a dataclass field: it is set in
    ``__post_init__`` from ``prompt_token_ids`` / ``prompt_embeds``.
    """

    req_id: str
    # None when the prompt is supplied through `prompt_embeds` instead
    # (see the error raised by `get_token_id`).
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        """Derive the prompt length and create pooling state when needed."""
        prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.num_prompt_tokens = prompt_len

        # A request that carries pooling params always starts with fresh
        # pooling state, replacing whatever was passed to the constructor.
        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Total number of tokens: prompt tokens plus generated tokens."""
        return len(self.output_token_ids) + self.num_prompt_tokens

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position ``idx`` of prompt + output.

        Args:
            idx: position counted from the start of the prompt.

        Returns:
            The token id at ``idx``, or ``-1`` when ``idx`` lies past the last
            generated token.

        Raises:
            ValueError: if ``idx`` falls inside the prompt but the prompt was
                provided via ``prompt_embeds`` (no token ids available).
        """
        output_pos = idx - self.num_prompt_tokens
        if output_pos < 0:
            # Position is inside the prompt.
            prompt_ids = self.prompt_token_ids
            if prompt_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown."
                )
            return prompt_ids[idx]
        if output_pos < len(self.output_token_ids):
            return self.output_token_ids[output_pos]
        return -1
79
80
81
82


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-request buffers of the batch.

        Args:
            max_num_reqs: number of request rows; first dimension of every
                per-request buffer created here.
            max_model_len: second dimension of the token-id buffers.
            max_num_batched_tokens: stored and forwarded to the block table.
            device: device on which the sampling-parameter tensors live.
            pin_memory: whether the per-request CPU tensors are pinned.
            vocab_size: stored for later use by the sampling bookkeeping.
            block_sizes: forwarded to ``MultiGroupBlockTable``.
            kernel_block_sizes: forwarded to ``MultiGroupBlockTable``.
            max_num_blocks_per_req: forwarded to ``MultiGroupBlockTable`` as
                ``max_num_blocks``.
            logitsprocs: logits processors to use; an empty
                ``LogitsProcessors()`` is created when falsy.
            logitsprocs_need_output_token_ids: stored as-is.
            is_spec_decode: stored as-is.
            is_pooling_model: stored as-is.
            cp_kv_cache_interleave_size: forwarded to ``MultiGroupBlockTable``.
        """
        per_req_shape = (max_num_reqs,)
        per_token_shape = (max_num_reqs, max_model_len)

        def host_buffer(factory, dtype: torch.dtype) -> torch.Tensor:
            # One value per request row, on the CPU, pinned as configured.
            return factory(
                per_req_shape, dtype=dtype, device="cpu", pin_memory=pin_memory
            )

        def device_buffer(dtype: torch.dtype) -> torch.Tensor:
            # One value per request row, on the target device.
            return torch.empty(per_req_shape, dtype=dtype, device=device)

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            per_token_shape, device="cpu", dtype=torch.int32, pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            per_token_shape, device="cpu", dtype=torch.bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}

        # Per-request token counters; each numpy attribute is a view of the
        # torch tensor next to it.
        self.num_tokens_no_spec_cpu_tensor = host_buffer(torch.zeros, torch.int32)
        self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
        self.num_prompt_tokens_cpu_tensor = host_buffer(torch.zeros, torch.int32)
        self.num_prompt_tokens = self.num_prompt_tokens_cpu_tensor.numpy()
        self.num_computed_tokens_cpu_tensor = host_buffer(torch.zeros, torch.int32)
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related. Every parameter has a device tensor, a CPU tensor,
        # a numpy view of the CPU tensor, and a set of the request ids for
        # which the parameter is active.
        self.temperature = device_buffer(torch.float32)
        self.temperature_cpu_tensor = host_buffer(torch.empty, torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = device_buffer(torch.float32)
        self.top_p_cpu_tensor = host_buffer(torch.empty, torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = device_buffer(torch.int32)
        self.top_k_cpu_tensor = host_buffer(torch.empty, torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = device_buffer(torch.float32)
        self.frequency_penalties_cpu_tensor = host_buffer(torch.empty, torch.float32)
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = device_buffer(torch.float32)
        self.presence_penalties_cpu_tensor = host_buffer(torch.empty, torch.float32)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = device_buffer(torch.float32)
        self.repetition_penalties_cpu_tensor = host_buffer(torch.empty, torch.float32)
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = host_buffer(torch.ones, torch.int32)
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros(self.max_num_reqs, dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # req_id -> list of specific token IDs to compute logprobs for
        # More efficient than num_logprobs=-1 when only a few tokens are needed
        self.logprob_token_ids: dict[str, list[int]] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs if logitsprocs else LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None
287

288
    @property
289
    def req_ids(self) -> list[str]:
290
291
        # None elements should only be present transiently
        # while performing state updates to the batch.
292
        return cast(list[str], self._req_ids)
293

294
    def _register_add_request(self, request: "CachedRequestState") -> int:
295
296
297
298
299
300
301
302
303
304
        """Track add-request operations for logits processors.
        Not applicable to pooling models.
        """

        # Fill the next empty index if there is one.
        if (new_req_index := self.batch_update_builder.pop_removed()) is None:
            # Append to end otherwise.
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
305
306
307
308
309
        self.batch_update_builder.batch_changed = True
        if request.sampling_params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            self.batch_update_builder.added.append(
310
311
312
313
314
315
316
                (
                    new_req_index,
                    request.sampling_params,
                    request.prompt_token_ids,
                    request.output_token_ids,
                )
            )
317

318
        return new_req_index
319

320
321
322
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the batch and populate its row of state.

        The row index comes from ``_register_add_request``. Token ids, token
        counters, the block table row, sampling or pooling parameters and the
        LoRA mapping for that row are then filled in from ``request``.

        Args:
            request: the request to add.

        Returns:
            The batch index the request now occupies.

        Raises:
            NotImplementedError: if the request has neither sampling params
                nor pooling params.
        """
        slot = self._register_add_request(request)
        req_id = request.req_id
        output_ids = request.output_token_ids

        if slot != len(self._req_ids):
            # Overwrite an existing row.
            self._req_ids[slot] = req_id
            self.req_output_token_ids[slot] = output_ids
            self.spec_token_ids[slot].clear()
        else:
            # The row is one past the current end of the lists.
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(output_ids)
            self.spec_token_ids.append([])
        self.req_id_to_index[req_id] = slot

        # Copy the prompt token ids and output token ids.
        prompt_ids = request.prompt_token_ids
        prompt_len = length_from_prompt_token_ids_or_embeds(
            prompt_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[slot] = prompt_len

        token_row = self.token_ids_cpu[slot]
        is_token_row = self.is_token_ids[slot]
        has_prompt_ids = prompt_ids is not None
        if has_prompt_ids:
            token_row[:prompt_len] = prompt_ids
        # Prompt positions count as token ids only when ids were supplied.
        is_token_row[:prompt_len] = has_prompt_ids
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[slot] = request.prompt_embeds

        output_end = prompt_len + len(output_ids)
        token_row[prompt_len:output_end] = output_ids
        is_token_row[prompt_len:output_end] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[slot] = request.num_tokens

        self.num_computed_tokens_cpu[slot] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, slot)

        sampling_params = request.sampling_params
        if sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[slot] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[slot] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[slot] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)

            # top_k outside (0, vocab_size) is stored as vocab_size and the
            # request is not tracked in top_k_reqs.
            top_k_is_active = 0 < sampling_params.top_k < self.vocab_size
            if top_k_is_active:
                self.top_k_reqs.add(req_id)
            self.top_k_cpu[slot] = (
                sampling_params.top_k if top_k_is_active else self.vocab_size
            )

            frequency_penalty = sampling_params.frequency_penalty
            self.frequency_penalties_cpu[slot] = frequency_penalty
            if frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)

            presence_penalty = sampling_params.presence_penalty
            self.presence_penalties_cpu[slot] = presence_penalty
            if presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)

            repetition_penalty = sampling_params.repetition_penalty
            self.repetition_penalties_cpu[slot] = repetition_penalty
            if repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[slot] = request.generator

            requested_logprobs = sampling_params.logprobs
            if requested_logprobs is not None:
                # -1 is stored as the full vocabulary size.
                if requested_logprobs == -1:
                    requested_logprobs = self.vocab_size
                self.num_logprobs[req_id] = requested_logprobs

            # Store specific token IDs to compute logprobs for (more efficient)
            if sampling_params.logprob_token_ids is not None:
                self.logprob_token_ids[req_id] = sampling_params.logprob_token_ids

            allowed_ids = sampling_params.allowed_token_ids
            if allowed_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    mask_shape = (self.max_num_reqs, self.vocab_size)
                    self.allowed_token_ids_mask = torch.zeros(
                        mask_shape, dtype=torch.bool, device=self.device
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        mask_shape, dtype=torch.bool, device="cpu"
                    )
                # Mask everything, then unmask the allowed ids.
                # False means we don't fill with -inf.
                row_mask = self.allowed_token_ids_mask_cpu_tensor[slot]
                row_mask.fill_(True)
                row_mask[allowed_ids] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[slot] = sampling_params.bad_words_token_ids
        elif request.pooling_params:
            pooling_params = request.pooling_params
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[slot] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[slot] = 1

        # Add request lora ID
        lora_request = request.lora_request
        if lora_request:
            lora_id = lora_request.lora_int_id
            self.request_lora_mapping[slot] = lora_id
            self.lora_id_to_request_ids.setdefault(lora_id, set()).add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = lora_request
        else:
            # No LoRA
            self.request_lora_mapping[slot] = 0

        return slot

463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        req_id = request.req_id
        req_index = self.req_id_to_index[req_id]
        cur_spec_token_ids = self.spec_token_ids[req_index]
        # When speculative decoding is used with structured output,
        # the scheduler can drop draft tokens that do not
        # conform to the schema. This can result in
        # scheduler_output.scheduled_spec_decode_tokens being empty,
        # even when speculative decoding is enabled.
        cur_spec_token_ids.clear()
        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
        num_spec_tokens = len(spec_token_ids)
        request.prev_num_draft_len = num_spec_tokens
        if not spec_token_ids:
            return

        # For async scheduling, token_ids_cpu assigned from
        # spec_token_ids are placeholders and will be overwritten in
        # _prepare_input_ids.
        start_index = self.num_tokens_no_spec[req_index]
        end_token_index = start_index + num_spec_tokens
        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
        cur_spec_token_ids.extend(spec_token_ids)

489
    def remove_request(self, req_id: str) -> int | None:
490
        """This method must always be followed by a call to condense().
491

492
493
494
495
496
497
        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
498

499
500
501
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
502
503

        self.batch_update_builder.removed_append(req_index)
504
505
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
506
        self.spec_token_ids[req_index].clear()
507
        self.block_table.clear_row(req_index)
508

509
510
511
512
513
514
515
516
517
518
519
520
        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
521
            self.pooling_states.pop(req_id, None)
522
523
            return req_index

524
525
526
527
        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
528
529
530
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
531
532
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
Vedant V Jhaveri's avatar
Vedant V Jhaveri committed
533
        self.logprob_token_ids.pop(req_id, None)
534
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
535
536
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)
537

538
539
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
540
            # False means we don't fill with -inf.
541
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
542
        self.bad_words_token_ids.pop(req_index, None)
543
544
        return req_index

545
546
547
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state stored at batch indices i1 and i2.

        Request ids, output/draft token lists, token counters, the active
        prefix of the token-id buffers, prompt embeddings, block table rows
        and the LoRA mapping are always swapped. For non-pooling models the
        swap is also recorded in ``batch_update_builder.moved`` and the
        sampling parameters, generators, bad-words lists and the
        allowed-token mask rows are swapped as well.

        Args:
            i1: first batch index; must currently hold a request.
            i2: second batch index; must currently hold a request.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        # Only swap the active token prefix for each request. Copying full
        # max_model_len rows is expensive and unnecessary during reordering.
        # The counts must be read before the counters and draft lists below
        # are swapped.
        i1_active_token_count = self._get_active_token_count(i1)
        i2_active_token_count = self._get_active_token_count(i2)
        max_active_token_count = max(i1_active_token_count, i2_active_token_count)

        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        # Indexing a 1-D numpy array with an int yields a scalar copy, so
        # tuple swaps are safe for the per-request counters.
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        tmp_token_ids = self.token_ids_cpu[i1, :max_active_token_count].copy()
        self.token_ids_cpu[i1, :max_active_token_count] = self.token_ids_cpu[
            i2, :max_active_token_count
        ]
        self.token_ids_cpu[i2, :max_active_token_count] = tmp_token_ids

        # Indexing with a list produces a copy on the right-hand side, so this
        # row swap does not alias.
        self.is_token_ids[[i1, i2], :max_active_token_count] = self.is_token_ids[
            [i2, i1], :max_active_token_count
        ]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # Indexing a 2-D torch tensor with an int returns a view of the
            # row, not a copy. A tuple swap would therefore overwrite row i1
            # first and then copy the already-overwritten row into i2, leaving
            # both rows equal to the old i2. Swap through a cloned temporary.
            mask = self.allowed_token_ids_mask_cpu_tensor
            tmp_mask_row = mask[i1].clone()
            mask[i1] = mask[i2]
            mask[i2] = tmp_mask_row
656

657
658
659
660
661
    def _get_active_token_count(self, req_index: int) -> int:
        return int(self.num_tokens_no_spec[req_index]) + len(
            self.spec_token_ids[req_index]
        )

662
663
664
665
666
667
668
669
670
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Each index listed in `self.batch_update_builder.removed` is a hole
        left by a removed request. While holes remain, the hole reported by
        `peek_removed()` is filled with the state of the request at the
        highest occupied index. Any consecutive empty indices at the very end
        of the list are not filled; the per-request lists are trimmed to
        `num_reqs` entries instead.

        Nothing is returned. State is moved in place and, for non-pooling
        models, every move is appended to `self.batch_update_builder.moved`
        as `(from_index, to_index, MoveDirectionality.UNIDIRECTIONAL)`.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                # Every remaining hole sits above the last live request, so
                # there is nothing left to move.
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # Must be read before the spec-token lists are swapped below: the
            # count includes the spec tokens still stored at last_req_index.
            num_tokens = self._get_active_token_count(last_req_index)

            # Swap the list objects (rather than copying) and then empty the
            # one that now sits in the vacated slot.
            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()

            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]
789

790
    def refresh_metadata(self):
        """Apply any batch updates to sampling metadata.

        Pooling models only reset the batch-update tracker; the sampling
        metadata is rebuilt when that reset reports a change. For all other
        models the pending batch update is handed to every logits processor
        first, and the sampling metadata is rebuilt when the update is truthy.
        """
        if self.is_pooling_model:
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            # Processors are notified even when batch_update is falsy.
            logit_proc.update_state(batch_update)
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()
807
808
809

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a `SamplingMetadata` covering the first `num_reqs` requests.

        Per-request sampling tensors are passed through `copy_slice` only when
        at least one request needs them (non-greedy sampling, top-p, top-k,
        penalties, allowed-token masks). Prompt and output token ids are
        attached only when penalties, bad words or logits processors need
        them.
        """
        num_reqs = self.num_reqs

        # NOTE(review): copy_slice(cpu_tensor, device_tensor, n) presumably
        # copies the first n entries to the device tensor and returns that
        # slice - confirm against vllm.v1.utils.copy_slice.
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs
            )
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(
                self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
            )
            copy_slice(
                self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
            )
            copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
            )

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids_cpu = (
            self._make_prompt_token_ids_cpu_tensor() if needs_prompt_token_ids else None
        )
        prompt_token_ids = (
            prompt_token_ids_cpu.to(device=self.device, non_blocking=True)
            if prompt_token_ids_cpu is not None
            else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not self.no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        # Build per-request logprob_token_ids mapping: req_index -> token_ids.
        # Request ids that are not in the batch index are left out.
        logprob_token_ids_by_index: dict[int, list[int]] | None = None
        if self.logprob_token_ids:
            index_of = self.req_id_to_index
            logprob_token_ids_by_index = {
                index_of[req_id]: token_ids
                for req_id, token_ids in self.logprob_token_ids.items()
                if req_id in index_of
            }

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            logprob_token_ids=logprob_token_ids_by_index,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

907
908
909
910
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return each request's pooling params, ordered like `self.req_ids`."""
        req_ids = self.req_ids
        params_by_req = self.pooling_params
        assert len(req_ids) == len(params_by_req)
        return [params_by_req[req_id] for req_id in req_ids]

911
912
913
914
    def get_pooling_states(self) -> list[PoolingStates]:
        """Return each request's pooling state, ordered like `self.req_ids`."""
        req_ids = self.req_ids
        states_by_req = self.pooling_states
        assert len(req_ids) == len(states_by_req)
        return [states_by_req[req_id] for req_id in req_ids]

915
916
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the `PoolingMetadata` for the requests in the batch.

        A CPU prompt-token tensor is built only when at least one request's
        pooling params set `requires_token_ids`; otherwise
        `prompt_token_ids_cpu` is `None`. The `prompt_token_ids` field is
        taken from the current sampling metadata.
        """
        pooling_params = self.get_pooling_params()
        pooling_states = self.get_pooling_states()

        prompt_token_ids_cpu = None
        if any(p.requires_token_ids for p in pooling_params):
            prompt_token_ids_cpu = self._make_prompt_token_ids_cpu_tensor()

        return PoolingMetadata(
            # Cloned so the metadata does not alias the batch's own tensor.
            prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            prompt_token_ids_cpu=prompt_token_ids_cpu,
            pooling_params=pooling_params,
            pooling_states=pooling_states,
        )

930
    def _make_prompt_token_ids_cpu_tensor(self) -> torch.Tensor:
931
932
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
933
934
935
936
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
937
938
            pin_memory=self.pin_memory,
        )
939
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
940
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
941
942
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
943
        for i in range(num_reqs):
944
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
945
        return prompt_token_ids_cpu_tensor
946

947
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens and num_sampled_tokens for each request
        in the batch, return datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where, prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        # repeat() emits each request's LoRA id once per counted token.
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

972
973
974
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
975
        async_copy_ready_event: torch.Event,
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        Only requests that were present in the previous step and whose output
        list ends in a -1 placeholder are touched. The copy-ready event is
        synchronized at most once, the first time a request needs repairing.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.tolist()
            # Replace placeholder token id(s) with actual sampled id(s).
            new_ids: list[int] = sampled_token_ids[prev_index]
            if not new_ids:
                continue
            # A row may be padded with -1; count only the ids before the
            # first pad.
            num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
            # Also account for case where there may be a smaller number of
            # output placeholders (tokens can be discarded after kv-load
            # failure) or a larger number (async spec decode adds optimistic
            # placeholders that may exceed the actual acceptance count).
            first_placeholder = req_output_token_ids.index(-1)
            num_placeholders = len(req_output_token_ids) - first_placeholder
            num_to_replace = min(num_sampled_ids, num_placeholders)
            del new_ids[num_to_replace:]
            req_output_token_ids[first_placeholder:] = new_ids
            # ^ Implicitly resizes to (first_placeholder + num_to_replace)
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        In async scheduling case, overwrite the spec_token_ids held by the
        sampling metadata with the real draft token ids from the prior step.
        This is called right before they are needed by the rejection sampler
        for penalty/bad_words computation.

        Each draft list is truncated in place to the length of the spec list
        it replaces.
        """
        if not draft_token_ids:
            return
        prev_index_of = self.prev_req_id_to_index
        if not prev_index_of:
            return

        all_spec_ids = self.sampling_metadata.spec_token_ids
        if all_spec_ids is None:
            return

        for req_id, spec_ids in zip(self.req_ids, all_spec_ids):
            if not spec_ids:
                continue
            prev_index = prev_index_of.get(req_id)
            if prev_index is None:
                continue
            draft_ids = draft_token_ids[prev_index]
            if not draft_ids:
                continue
            # Keep no more drafts than there were spec entries.
            del draft_ids[len(spec_ids) :]
            spec_ids.clear()
            spec_ids.extend(draft_ids)
1050

1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
    @property
    def num_reqs(self) -> int:
        """Number of entries in `req_id_to_index`, i.e. requests in the batch."""
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        """True when `random_reqs` is empty."""
        return not len(self.random_reqs)

    @property
    def all_random(self) -> bool:
        """True when `greedy_reqs` is empty."""
        return not len(self.greedy_reqs)

    @property
    def no_top_p(self) -> bool:
        """True when `top_p_reqs` is empty."""
        return not len(self.top_p_reqs)

    @property
    def no_top_k(self) -> bool:
        """True when `top_k_reqs` is empty."""
        return not len(self.top_k_reqs)

1071
1072
    @property
    def no_penalties(self) -> bool:
        """True when the presence, frequency and repetition penalty request
        collections are all empty."""
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )
1078

1079
    @property
1080
    def max_num_logprobs(self) -> int | None:
1081
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1082

1083
1084
1085
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when `has_allowed_token_ids` is empty."""
        return not len(self.has_allowed_token_ids)