gpu_input_batch.py 43.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import Optional, cast
7
8
9
10

import numpy as np
import torch

11
12
from vllm import envs
from vllm.config import get_current_vllm_config
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams, SamplingType
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
from vllm.utils.collection_utils import swap_dict_values
19
from vllm.v1.outputs import LogprobsTensors
20
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
21
22
23
24
25
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
26
from vllm.v1.sample.metadata import SamplingMetadata
27
from vllm.v1.utils import copy_slice
28
from vllm.v1.worker.block_table import MultiGroupBlockTable
29
30
31
32
33


@dataclass
class CachedRequestState:
    """Worker-side cached state of a single request.

    Used to (re)populate rows of ``InputBatch``. Field declaration order is
    significant: it fixes the positional order of the generated ``__init__``.
    """

    req_id: str

    # None when the prompt was supplied as embeddings instead of token IDs.
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    num_kv_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # Chunked prefill (scheme 3): cached prompt compaction plan.
    # Computed on the last prompt chunk; applied before the first decode step.
    kv_compression_prompt_idx_sorted: torch.Tensor | None = None  # [K] int32
    kv_compression_prompt_keep_len: int | None = None
    kv_compression_prompt_prompt_len: int | None = None

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        # Derived, non-field attribute: prompt length taken from whichever of
        # token IDs / embeddings is available.
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        # A pooling request always starts from a fresh PoolingStates, even if
        # one was passed to the constructor.
        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Total length so far: prompt tokens plus generated tokens."""
        generated = len(self.output_token_ids)
        return self.num_prompt_tokens + generated

    def get_token_id(self, idx: int) -> int:
        """Return the token ID at absolute position ``idx``.

        Positions below ``num_prompt_tokens`` address the prompt; later
        positions address ``output_token_ids``. A position past the last
        generated token yields -1.

        Raises:
            ValueError: if ``idx`` lies in a prompt that was given only as
                embeddings, so no token ID exists for it.
        """
        offset = idx - self.num_prompt_tokens
        if offset >= 0:
            # Position is in (or beyond) the generated part.
            if offset < len(self.output_token_ids):
                return self.output_token_ids[offset]
            return -1
        if self.prompt_token_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return self.prompt_token_ids[idx]


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent batch buffers.

        Args:
            max_num_reqs: Maximum number of requests. When both
                ``is_spec_decode`` and ``envs.VLLM_REJECT_SAMPLE_OPT`` are
                set, all buffers except the per-token ``token_ids_cpu`` /
                ``is_token_ids`` are sized
                ``max_num_reqs * (1 + num_speculative_tokens)`` instead.
            max_model_len: Number of columns of the per-token buffers.
            max_num_batched_tokens: Stored and forwarded to the block table.
            device: Device of the GPU-side sampling tensors.
            pin_memory: Whether CPU staging tensors are allocated pinned.
            vocab_size: Vocabulary size; also the top-k value stored for
                requests whose top-k is outside ``(0, vocab_size)``.
            block_sizes: The block_size of each kv cache group.
            kernel_block_sizes: Forwarded to the block table.
            max_num_blocks_per_req: Forwarded to the block table.
            logitsprocs: Logits processors; an empty set is created if None.
            logitsprocs_need_output_token_ids: Stored as-is.
            is_spec_decode: Whether speculative decoding is enabled.
            is_pooling_model: Whether this batch serves a pooling model.
            cp_kv_cache_interleave_size: Forwarded to the block table.
        """
        # Keep the caller's value: token_ids_cpu (and is_token_ids, which
        # must match it) are sized with it rather than the inflated count.
        ori_max_num_reqs = max_num_reqs
        if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
            # NOTE(review): presumably one row per (request, draft position)
            # for the rejection-sampling optimization -- confirm against the
            # sampler that consumes these buffers.
            vllm_config = get_current_vllm_config()
            num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens
            max_num_reqs = max_num_reqs * (1 + num_spec_tokens)

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (ori_max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Must have the same shape as token_ids_cpu: swap_states() and
        # condense() move rows of both buffers with identical indices.
        self.is_token_ids_tensor = torch.zeros(
            (ori_max_num_reqs, max_model_len),
            device="cpu",
            dtype=bool,
            pin_memory=False,
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
        self.num_kv_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related.
        self.temperature = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device=device
        )
        self.temperature_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
        self.top_p_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
        self.top_k_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.presence_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None

        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None

    @property
    def req_ids(self) -> list[str]:
        """Request IDs in batch-slot order.

        ``None`` placeholders can appear only transiently while the batch is
        being updated (``remove_request`` writes one, to be resolved by the
        following ``condense``); the cast hides them from type checkers.
        """
        slots = self._req_ids
        return cast(list[str], slots)

    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Pick a batch slot for ``request`` and record the addition.

        A slot freed by an earlier removal is reused when the batch-update
        builder has one pending; otherwise the request is placed at index
        ``self.num_reqs``. The builder is always marked as changed. Detailed
        add metadata (consumed by logits processors) is queued only for
        requests with truthy ``sampling_params``, so it is not produced for
        pooling requests.

        Returns:
            The chosen slot index.
        """
        builder = self.batch_update_builder
        slot = builder.pop_removed()
        if slot is None:
            # No hole to fill: place the request after the current ones.
            slot = self.num_reqs

        assert slot < self.max_num_reqs

        builder.batch_changed = True
        sampling_params = request.sampling_params
        if sampling_params:
            added_entry = (
                slot,
                sampling_params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
            builder.added.append(added_entry)

        return slot

    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the persistent batch.

        The slot is chosen by ``_register_add_request``. The request's prompt
        and output tokens are copied into ``token_ids_cpu``, and its
        sampling or pooling parameters and LoRA mapping are recorded for
        that slot.

        Returns:
            The batch slot index assigned to the request.

        Raises:
            NotImplementedError: if the request carries neither sampling nor
                pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # The slot may be reused from a removed request that had prompt
            # embeddings; drop them so they are not attributed to this one.
            self.req_prompt_embeds.pop(req_index, None)
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top-k is stored as "whole vocabulary".
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # -1 requests logprobs over the full vocabulary.
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        """Replace the draft (speculative) tokens cached for ``request``.

        The slot's draft list is always emptied first. With structured
        output, the scheduler can drop draft tokens that do not conform to
        the schema, so ``scheduled_spec_tokens`` may have no entry for the
        request even when speculative decoding is enabled.
        ``request.prev_num_draft_len`` is set to the number of drafts
        scheduled this step, which may be 0.
        """
        req_index = self.req_id_to_index[request.req_id]
        draft_buffer = self.spec_token_ids[req_index]
        draft_buffer.clear()

        new_drafts = scheduled_spec_tokens.get(request.req_id, ())
        count = len(new_drafts)
        request.prev_num_draft_len = count
        if not new_drafts:
            return

        # The drafts go right after the request's non-speculative tokens.
        # Under async scheduling these values are placeholders that
        # _prepare_input_ids overwrites later.
        begin = self.num_tokens_no_spec[req_index]
        self.token_ids_cpu[req_index, begin : begin + count] = new_drafts
        draft_buffer.extend(new_drafts)

    def remove_request(self, req_id: str) -> int | None:
        """This method must always be followed by a call to condense().

        Frees the request's slot and releases the per-request state stored
        for it. The freed index is reported to the batch-update builder so a
        later add_request() can reuse it, or condense() can fill it.

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()

        # Release this request's prompt embeddings. add_request() writes this
        # entry only when embeddings are present, so without this pop a stale
        # tensor would outlive the request and stay attached to the slot.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last user of this adapter in the batch: forget it entirely.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            self.pooling_states.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange the per-slot state of batch slots ``i1`` and ``i2``.

        Both slots must hold live requests. For non-pooling models the swap
        is also recorded in ``batch_update_builder.moved`` so that logits
        processors can follow the reordering, and sampling state is swapped.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )
        self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] = (
            self.num_kv_tokens_cpu[i2],
            self.num_kv_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # because the right-hand side holds views, not copies. Instead, we
        # need to temporarily copy the data for one of the indices.
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        # Advanced indexing on the right-hand side produces a copy, so this
        # row swap is safe.
        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        # Per-slot flag written by add_request() for pooling requests; keep it
        # attached to the request it describes (fancy-index copy, see above).
        self.logits_processing_needs_token_ids[[i1, i2]] = (
            self.logits_processing_needs_token_ids[[i2, i1]]
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        # Indexing a 1-D numpy array with a scalar yields a scalar copy, so
        # tuple swaps are safe for the per-request scalar arrays below.
        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # ``mask[i]`` is a view of row i. A tuple swap of two row views
            # copies row i2 into i1 and then the already-overwritten row i1
            # back into i2, leaving both rows equal to the old row i2.
            # ``mask[[i2, i1]]`` (advanced indexing) is a copy, so this is
            # a correct row swap.
            mask = self.allowed_token_ids_mask_cpu_tensor
            mask[[i1, i2]] = mask[[i2, i1]]

    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        Returns:
          swaps: list of (from,to) swap tuples for moved requests
          empty_req_indices: indices not filled by condensation
661
        """
662
663
        num_reqs = self.num_reqs

664
665
666
667
        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
668
        if num_reqs == 0:
669
            # The batched states are empty.
670
671
            self._req_ids.clear()
            self.req_output_token_ids.clear()
672
            self.spec_token_ids.clear()
673
674
675
676
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
677
        last_req_index = num_reqs + len(empty_req_indices) - 1
678
679
680
681
682
683
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
684
685
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
686
687
688
            if empty_index >= last_req_index:
                break

689
690
691
            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
692
693
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
694
            assert req_id is not None
695
696
697
698
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
699
700
            self.req_id_to_index[req_id] = empty_index

701
702
703
704
705
706
707
708
709
            num_tokens = self.num_tokens_no_spec[last_req_index] + len(
                self.spec_token_ids[last_req_index]
            )

            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()
710

711
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
712
713
                last_req_index, :num_tokens
            ]
714
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
715
716
                last_req_index, :num_tokens
            ]
717
            if last_req_index in self.req_prompt_embeds:
718
719
720
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
721
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
722
723
724
725
726
727
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
728
            self.num_kv_tokens_cpu[empty_index] = self.num_kv_tokens_cpu[last_req_index]
729
            self.block_table.move_row(last_req_index, empty_index)
730
731

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
732
733
                last_req_index
            ]
734
735
736

            if self.is_pooling_model:
                last_req_index -= 1
co63oc's avatar
co63oc committed
737
                # Sampling state not used by pooling models.
738
739
740
741
742
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
743
744
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )
745

746
            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
747
748
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
749
750
751
752
753
754
755
756
757
758
759
760
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
761
762
763
764
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

765
            # TODO convert these to LogitsProcessors
766
            if self.allowed_token_ids_mask_cpu_tensor is not None:
767
768
769
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
770

771
            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
772
773
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
774

775
776
777
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

778
        # Trim lists to the batch size.
779
780
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
781
        del self.spec_token_ids[num_reqs:]
782

783
    def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
        """Apply pending batch updates and rebuild sampling metadata if needed.

        Pooling models only reset the batch-update tracker. All other models
        fetch the accumulated batch update, hand it to every logits processor
        via ``update_state``, and reset the tracker in the same call.
        ``self.sampling_metadata`` is rebuilt only when the batch changed.

        Args:
            repeat_counts: forwarded unchanged to ``_make_sampling_metadata``.
        """
        if not self.is_pooling_model:
            # Logits processors must see every batch update (even a falsy
            # one) so that their per-request state stays aligned.
            pending_update = self.batch_update_builder.get_and_reset(self.num_reqs)
            for processor in self.logitsprocs.all:
                processor.update_state(pending_update)
            changed = bool(pending_update)
        else:
            # Pooling models have no logits processors to notify.
            changed = self.batch_update_builder.reset()

        if changed:
            self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
800

801
    def _make_sampling_metadata(
        self, repeat_counts: Optional[torch.Tensor] = None
    ) -> SamplingMetadata:
        """Build a SamplingMetadata snapshot for the current batch.

        Each optional per-request parameter (temperature, top-p, top-k, the
        three penalties, the allowed-token mask) is copied from its CPU
        staging tensor to its device tensor with ``copy_slice`` only when at
        least one request in the batch uses it. Otherwise the corresponding
        field is ``None``.

        Args:
            repeat_counts: forwarded unchanged to every ``copy_slice`` call.
                NOTE(review): presumably per-request row repeat factors; the
                exact semantics live in ``copy_slice`` -- confirm there. The
                list-valued fields (prompt/output/spec token ids, generators)
                are built without ``repeat_counts``; confirm that is intended
                when it is set.

        Returns:
            A new SamplingMetadata covering the first ``self.num_reqs`` requests.
        """
        num_reqs = self.num_reqs

        # Every optional tensor starts as None so each name is bound on every
        # path and can be passed to SamplingMetadata directly.
        temperature = None
        top_p = None
        top_k = None
        frequency_penalties = None
        presence_penalties = None
        repetition_penalties = None
        allowed_token_ids_mask: torch.Tensor | None = None

        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs, repeat_counts
            )
        if not self.no_top_p:
            top_p = copy_slice(
                self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts
            )
        if not self.no_top_k:
            top_k = copy_slice(
                self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts
            )

        if not self.no_penalties:
            # Since syncing these tensors is expensive, only copy them if
            # some request needs penalties applied during sampling.
            frequency_penalties = copy_slice(
                self.frequency_penalties_cpu_tensor,
                self.frequency_penalties,
                num_reqs,
                repeat_counts,
            )
            presence_penalties = copy_slice(
                self.presence_penalties_cpu_tensor,
                self.presence_penalties,
                num_reqs,
                repeat_counts,
            )
            repetition_penalties = copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
                repeat_counts,
            )

        # The prompt tokens are used only for applying penalties or step
        # pooling, so build the tensor only when some request needs it.
        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only expose output_token_ids when the current requests' sampling
        # parameters (penalties, bad words, logits processors) need them.
        needs_output_token_ids = (
            not self.no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            allowed_token_ids_mask = copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
                repeat_counts,
            )

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=top_p,
            top_k=top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=frequency_penalties,
            presence_penalties=presence_penalties,
            repetition_penalties=repetition_penalties,
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

882
883
884
885
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the PoolingParams of every request, in batch order.

        Asserts that the number of stored pooling params equals the number of
        requests in the batch; a missing request id raises KeyError.
        """
        batch_req_ids = self.req_ids
        params_by_req = self.pooling_params
        assert len(batch_req_ids) == len(params_by_req)
        return list(map(params_by_req.__getitem__, batch_req_ids))

886
887
888
889
    def get_pooling_states(self) -> list[PoolingStates]:
        """Return the PoolingStates of every request, in batch order.

        Asserts that the number of stored pooling states equals the number of
        requests in the batch; a missing request id raises KeyError.
        """
        batch_req_ids = self.req_ids
        states_by_req = self.pooling_states
        assert len(batch_req_ids) == len(states_by_req)
        return list(map(states_by_req.__getitem__, batch_req_ids))

890
891
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the PoolingMetadata for the current batch.

        ``prompt_lens`` is created with ``torch.from_numpy``, so it shares
        memory with ``self.num_prompt_tokens``; it is not a copy. Prompt token
        ids are taken from the current ``self.sampling_metadata``.
        """
        params = self.get_pooling_params()
        states = self.get_pooling_states()
        prompt_lens = torch.from_numpy(self.num_prompt_tokens[: self.num_reqs])
        return PoolingMetadata(
            prompt_lens=prompt_lens,
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=params,
            pooling_states=states,
        )

901
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
902
903
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
904
905
906
907
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
908
909
            pin_memory=self.pin_memory,
        )
910
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
911
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
912
913
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
914
        for i in range(num_reqs):
915
916
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
917

918
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Expand the per-request LoRA ids of the current batch into the
        structures used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: per-request count of scheduled tokens; each
                request's LoRA id is repeated that many times.
            num_sampled_tokens: per-request count of sampled tokens; each
                request's LoRA id is repeated that many times.

        Returns:
            1. prompt_lora_mapping: tuple of size np.sum(num_sampled_tokens);
               entry i is the LoRA id to use for the i-th sampled token.
            2. token_lora_mapping: tuple of size np.sum(num_scheduled_tokens);
               entry i is the LoRA id to use for the i-th scheduled token.
            3. lora_requests: every LoRARequest held in
               self.lora_id_to_lora_request.
        """
        per_request_ids = self.request_lora_mapping[: self.num_reqs]

        def _expand(counts: np.ndarray) -> tuple[int, ...]:
            # ndarray.repeat with an array argument repeats element j
            # counts[j] times, preserving request order.
            return tuple(per_request_ids.repeat(counts))

        prompt_lora_mapping = _expand(num_sampled_tokens)
        token_lora_mapping = _expand(num_scheduled_tokens)
        lora_requests: set[LoRARequest] = set(self.lora_id_to_lora_request.values())
        return prompt_lora_mapping, token_lora_mapping, lora_requests

943
944
945
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
946
        async_copy_ready_event: torch.Event,
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        Fill ``-1`` placeholders in output_token_ids with real sampled ids.

        In the async scheduling case, the previous step's sampled ids arrive
        in ``self.sampled_token_ids_cpu`` via an async copy. This runs right
        before the logits processors need the ids. It waits on the copy-ready
        event (only once, and only if some request needs repair) and then,
        for each request that was also in the previous batch and whose output
        list ends in a ``-1`` placeholder, overwrites placeholders starting at
        the first ``-1`` with that request's sampled ids.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Either not async scheduling, or output token ids are not needed.
            return

        prev_index_of = self.prev_req_id_to_index
        assert prev_index_of is not None
        sampled_rows = None  # materialized lazily, after the copy completes
        for batch_index, req_id in enumerate(self.req_ids):
            prev_index = prev_index_of.get(req_id)
            if prev_index is None:
                # Request was not in the previous batch; nothing to repair.
                continue
            current_ids = output_token_ids[batch_index]
            if not current_ids or current_ids[-1] != -1:
                # Final output id is not a placeholder, so some tokens must
                # have been discarded after a kv-load failure.
                continue
            if sampled_rows is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_rows = self.sampled_token_ids_cpu.tolist()

            fresh_ids: list[int] = sampled_rows[prev_index]
            if not fresh_ids:
                continue
            # A trailing -1 in the sampled row marks unused slots; only ids
            # before the first -1 are real.
            if fresh_ids[-1] == -1:
                num_sampled = fresh_ids.index(-1)
            else:
                num_sampled = len(fresh_ids)
            # There may be fewer placeholders than sampled ids (tokens can be
            # discarded after a kv-load failure), so replace at most that many.
            start = current_ids.index(-1)
            count = min(num_sampled, len(current_ids) - start)
            del fresh_ids[count:]
            current_ids[start : start + count] = fresh_ids

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        Replace placeholder spec token ids with the prior step's real drafts.

        In the async scheduling case, ``sampling_metadata.spec_token_ids``
        holds placeholders until the real draft token ids from the prior step
        are known. This runs right before the rejection sampler needs them
        for penalty/bad_words computation. For every request with non-empty
        spec ids that was in the previous batch and has non-empty drafts, its
        draft list is truncated in place to the placeholder count and copied
        into the spec list.

        Args:
            draft_token_ids: draft ids indexed by previous-batch position.
                NOTE: entries are truncated in place.
        """
        if not draft_token_ids or not self.prev_req_id_to_index:
            return
        spec_token_ids = self.sampling_metadata.spec_token_ids
        if spec_token_ids is None:
            return

        for req_id, placeholder_ids in zip(self.req_ids, spec_token_ids):
            if not placeholder_ids:
                continue
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            real_ids = draft_token_ids[prev_index]
            if not real_ids:
                continue
            # Keep no more drafts than there were placeholders.
            del real_ids[len(placeholder_ids) :]
            placeholder_ids.clear()
            placeholder_ids.extend(real_ids)
1019

1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
    @property
    def num_reqs(self) -> int:
        """Number of requests currently in the batch (entries in req_id_to_index)."""
        active = self.req_id_to_index
        return len(active)

    @property
    def all_greedy(self) -> bool:
        """True when no request in the batch is tracked in random_reqs."""
        return not len(self.random_reqs)

    @property
    def all_random(self) -> bool:
        """True when no request in the batch is tracked in greedy_reqs."""
        return not len(self.greedy_reqs)

    @property
    def no_top_p(self) -> bool:
        """True when no request in the batch is tracked in top_p_reqs."""
        return not len(self.top_p_reqs)

    @property
    def no_top_k(self) -> bool:
        """True when no request in the batch is tracked in top_k_reqs."""
        return not len(self.top_k_reqs)

1040
1041
    @property
    def no_penalties(self) -> bool:
        """True when no request uses presence, frequency or repetition penalties."""
        return not (
            len(self.presence_penalties_reqs)
            or len(self.frequency_penalties_reqs)
            or len(self.repetition_penalties_reqs)
        )
1047

1048
    @property
1049
    def max_num_logprobs(self) -> int | None:
1050
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1051

1052
1053
1054
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when no request is tracked in has_allowed_token_ids."""
        return not len(self.has_allowed_token_ids)