"vllm/vscode:/vscode.git/clone" did not exist on "86e9c8df29a954a7a2fc46e9985fecc2a2e15ae8"
gpu_input_batch.py 43.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.utils import copy_slice
26
from vllm.v1.worker.block_table import MultiGroupBlockTable
27
28
29
30
31


@dataclass
class CachedRequestState:
    """Per-request state cached between steps of a GPU input batch.

    ``num_prompt_tokens`` is not a constructor field: it is derived in
    ``__post_init__`` from ``prompt_token_ids`` / ``prompt_embeds``.
    """

    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # Only meaningful for pooling models.
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        # The prompt may be supplied as token ids or as embeddings; the helper
        # derives the prompt length from whichever one is present.
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )

        # A request carrying pooling params always starts with fresh pooling
        # state, replacing whatever was passed to the constructor.
        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Prompt length plus the number of output tokens so far."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position ``idx`` of prompt + output.

        Returns:
            The token id, or ``-1`` when ``idx`` lies past the last output
            token.

        Raises:
            ValueError: ``idx`` falls inside the prompt but the prompt has no
                token ids (it was provided via ``prompt_embeds``).
        """
        prompt_len = self.num_prompt_tokens
        if idx >= prompt_len:
            # Position is in the generated part (or beyond it).
            output_offset = idx - prompt_len
            generated = self.output_token_ids
            if output_offset < len(generated):
                return generated[output_offset]
            return -1

        if self.prompt_token_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return self.prompt_token_ids[idx]
79
80
81
82


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-request buffers of the batch.

        Args:
            max_num_reqs: Number of request slots; first dimension of every
                per-request buffer.
            max_model_len: Second dimension of the token-id buffers.
            max_num_batched_tokens: Stored and forwarded to the block table.
            device: Device holding the device-side sampling tensors.
            pin_memory: Whether the small per-request CPU buffers are pinned.
            vocab_size: Vocabulary size (used for top-k clamping, logprobs
                and the allowed-token mask width).
            block_sizes: The block size of each KV cache group.
            kernel_block_sizes: Forwarded to ``MultiGroupBlockTable``.
            max_num_blocks_per_req: Forwarded to ``MultiGroupBlockTable`` as
                ``max_num_blocks``.
            logitsprocs: Logits processors to use; an empty
                ``LogitsProcessors`` is created when none (or a falsy one)
                is given.
            logitsprocs_need_output_token_ids: Stored as-is.
            is_spec_decode: Stored as-is.
            is_pooling_model: Whether the batch serves a pooling model.
            cp_kv_cache_interleave_size: Forwarded to
                ``MultiGroupBlockTable``.
        """
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # Shapes / kwargs shared by the buffers below.
        per_req = (max_num_reqs,)
        per_req_per_token = (max_num_reqs, max_model_len)
        host = {"device": "cpu", "pin_memory": pin_memory}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            per_req_per_token, device="cpu", dtype=torch.int32, pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            per_req_per_token, device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()

        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}

        self.num_tokens_no_spec_cpu_tensor = torch.zeros(
            per_req, dtype=torch.int32, **host
        )
        self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()

        self.num_prompt_tokens_cpu_tensor = torch.zeros(
            per_req, dtype=torch.int32, **host
        )
        self.num_prompt_tokens = self.num_prompt_tokens_cpu_tensor.numpy()

        self.num_computed_tokens_cpu_tensor = torch.zeros(
            per_req, dtype=torch.int32, **host
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related. Each parameter has a device tensor, a CPU tensor
        # and a numpy view of the CPU tensor.
        self.temperature = torch.empty(per_req, dtype=torch.float32, device=device)
        self.temperature_cpu_tensor = torch.empty(
            per_req, dtype=torch.float32, **host
        )
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty(per_req, dtype=torch.float32, device=device)
        self.top_p_cpu_tensor = torch.empty(per_req, dtype=torch.float32, **host)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty(per_req, dtype=torch.int32, device=device)
        self.top_k_cpu_tensor = torch.empty(per_req, dtype=torch.int32, **host)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty(
            per_req, dtype=torch.float, device=device
        )
        self.frequency_penalties_cpu_tensor = torch.empty(
            per_req, dtype=torch.float, **host
        )
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty(
            per_req, dtype=torch.float, device=device
        )
        self.presence_penalties_cpu_tensor = torch.empty(
            per_req, dtype=torch.float, **host
        )
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty(
            per_req, dtype=torch.float, device=device
        )
        self.repetition_penalties_cpu_tensor = torch.empty(
            per_req, dtype=torch.float, **host
        )
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding: initialised to one accepted token per slot.
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            per_req, dtype=torch.int32, **host
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs if logitsprocs else LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        # Built here, after the sampling-related state above has been set.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None
283

284
    @property
285
    def req_ids(self) -> list[str]:
286
287
        # None elements should only be present transiently
        # while performing state updates to the batch.
288
        return cast(list[str], self._req_ids)
289

290
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Pick the batch index for a new request and record the addition.

        Detailed add-request tracking (for logits processors) is not
        applicable to pooling models and is skipped when the request has no
        sampling params.

        Returns:
            The index the request should occupy.
        """
        builder = self.batch_update_builder

        # Reuse a previously vacated index when one is available; otherwise
        # the request goes at the end of the batch.
        slot = builder.pop_removed()
        if slot is None:
            slot = self.num_reqs

        assert slot < self.max_num_reqs
        builder.batch_changed = True

        params = request.sampling_params
        if params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            added_entry = (
                slot,
                params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
            builder.added.append(added_entry)

        return slot
315

316
317
318
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the batch and populate its slot.

        Returns:
            The batch index assigned to the request.

        Raises:
            NotImplementedError: The request has neither sampling params nor
                pooling params.
        """
        req_index = self._register_add_request(request)
        req_id = request.req_id
        output_ids = request.output_token_ids

        if req_index != len(self._req_ids):
            # Reusing an existing slot.
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = output_ids
            self.spec_token_ids[req_index].clear()
        else:
            # Growing the batch by one slot.
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(output_ids)
            self.spec_token_ids.append([])

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        prompt_ids = request.prompt_token_ids
        prompt_embeds = request.prompt_embeds
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_ids, prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        output_start = num_prompt_tokens
        output_end = output_start + len(output_ids)

        # Prompt positions are flagged as token ids only when the prompt was
        # actually given as token ids.
        has_prompt_ids = prompt_ids is not None
        if has_prompt_ids:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = prompt_ids
        self.is_token_ids[req_index, :num_prompt_tokens] = has_prompt_ids
        if prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = prompt_embeds

        self.token_ids_cpu[req_index, output_start:output_end] = output_ids
        self.is_token_ids[req_index, output_start:output_end] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        if sampling_params:
            if sampling_params.sampling_type != SamplingType.GREEDY:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)
            else:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)

            top_p = sampling_params.top_p
            self.top_p_cpu[req_index] = top_p
            if top_p < 1:
                self.top_p_reqs.add(req_id)

            # A top_k outside (0, vocab_size) is stored as vocab_size and the
            # request is not tracked as a top-k request.
            top_k = sampling_params.top_k
            if not (0 < top_k < self.vocab_size):
                top_k = self.vocab_size
            else:
                self.top_k_reqs.add(req_id)
            self.top_k_cpu[req_index] = top_k

            frequency_penalty = sampling_params.frequency_penalty
            self.frequency_penalties_cpu[req_index] = frequency_penalty
            if frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)

            presence_penalty = sampling_params.presence_penalty
            self.presence_penalties_cpu[req_index] = presence_penalty
            if presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)

            repetition_penalty = sampling_params.repetition_penalty
            self.repetition_penalties_cpu[req_index] = repetition_penalty
            if repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            logprobs = sampling_params.logprobs
            if logprobs is not None:
                # -1 requests logprobs over the whole vocabulary.
                if logprobs == -1:
                    self.num_logprobs[req_id] = self.vocab_size
                else:
                    self.num_logprobs[req_id] = logprobs

            allowed_token_ids = sampling_params.allowed_token_ids
            if allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    mask_shape = (self.max_num_reqs, self.vocab_size)
                    self.allowed_token_ids_mask = torch.zeros(
                        *mask_shape, dtype=torch.bool, device=self.device
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        *mask_shape, dtype=torch.bool, device="cpu"
                    )
                # Mask everything for this request, then unmask the allowed
                # ids. False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    allowed_token_ids
                ] = False

            bad_words = sampling_params.bad_words_token_ids
            if bad_words:
                self.bad_words_token_ids[req_index] = bad_words
        else:
            pooling_params = request.pooling_params
            if not pooling_params:
                raise NotImplementedError("Unrecognized request type")

            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        lora_request = request.lora_request
        if not lora_request:
            # No LoRA
            self.request_lora_mapping[req_index] = 0
        else:
            lora_id = lora_request.lora_int_id
            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids.setdefault(lora_id, set()).add(
                request.req_id
            )
            self.lora_id_to_lora_request[lora_id] = lora_request

        return req_index

455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        req_id = request.req_id
        req_index = self.req_id_to_index[req_id]
        cur_spec_token_ids = self.spec_token_ids[req_index]
        # When speculative decoding is used with structured output,
        # the scheduler can drop draft tokens that do not
        # conform to the schema. This can result in
        # scheduler_output.scheduled_spec_decode_tokens being empty,
        # even when speculative decoding is enabled.
        cur_spec_token_ids.clear()
        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
        num_spec_tokens = len(spec_token_ids)
        request.prev_num_draft_len = num_spec_tokens
        if not spec_token_ids:
            return

        # For async scheduling, token_ids_cpu assigned from
        # spec_token_ids are placeholders and will be overwritten in
        # _prepare_input_ids.
        start_index = self.num_tokens_no_spec[req_index]
        end_token_index = start_index + num_spec_tokens
        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
        cur_spec_token_ids.extend(spec_token_ids)

481
    def remove_request(self, req_id: str) -> int | None:
        """This method must always be followed by a call to condense().

        Releases everything the batch holds for the request's slot and
        records the slot as removed in the batch update builder.

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """

        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()
        self.block_table.clear_row(req_index)

        # Drop the slot's prompt embeddings. Otherwise the tensor stays alive
        # after the request is gone and would be inherited by whichever
        # request is later placed in this slot without embeddings of its own.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this adapter: forget the adapter too.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            # Pooling models keep no sampling state, so stop here.
            self.pooling_params.pop(req_id, None)
            self.pooling_states.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

536
537
538
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state held at batch indices i1 and i2.

        Both indices must currently hold a request. For pooling models only
        the request/token/block-table/LoRA state is swapped; sampling state
        is swapped, and the swap recorded for logits processors, only for
        non-pooling models.

        Args:
            i1: First batch index.
            i2: Second batch index.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        # Validate before touching any state so a failure leaves the batch
        # unchanged.
        assert old_id_i1 is not None and old_id_i2 is not None

        # Only swap the active token prefix for each request. Copying full
        # max_model_len rows is expensive and unnecessary during reordering.
        i1_active_token_count = self._get_active_token_count(i1)
        i2_active_token_count = self._get_active_token_count(i2)
        max_active_token_count = max(i1_active_token_count, i2_active_token_count)

        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        # Tuple swaps are safe for these 1-D arrays: indexing yields scalar
        # copies, not views.
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        tmp_token_ids = self.token_ids_cpu[i1, :max_active_token_count].copy()
        self.token_ids_cpu[i1, :max_active_token_count] = self.token_ids_cpu[
            i2, :max_active_token_count
        ]
        self.token_ids_cpu[i2, :max_active_token_count] = tmp_token_ids

        # Fancy indexing on the right-hand side produces a copy, so this
        # row swap is safe.
        self.is_token_ids[[i1, i2], :max_active_token_count] = self.is_token_ids[
            [i2, i1], :max_active_token_count
        ]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # Rows of a 2-D tensor are views, so a tuple swap
            # (`t[i1], t[i2] = t[i2], t[i1]`) would overwrite row i1 before
            # it is read and leave both rows equal to the old row i2.
            # Indexing with a list materialises a copy of both rows first.
            self.allowed_token_ids_mask_cpu_tensor[[i1, i2]] = (
                self.allowed_token_ids_mask_cpu_tensor[[i2, i1]]
            )
647

648
649
650
651
652
    def _get_active_token_count(self, req_index: int) -> int:
        return int(self.num_tokens_no_spec[req_index]) + len(
            self.spec_token_ids[req_index]
        )

653
654
655
656
657
658
659
660
661
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        Returns:
          swaps: list of (from,to) swap tuples for moved requests
          empty_req_indices: indices not filled by condensation
662
        """
663
664
        num_reqs = self.num_reqs

665
666
667
668
        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
669
        if num_reqs == 0:
670
            # The batched states are empty.
671
672
            self._req_ids.clear()
            self.req_output_token_ids.clear()
673
            self.spec_token_ids.clear()
674
675
676
677
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
678
        last_req_index = num_reqs + len(empty_req_indices) - 1
679
680
681
682
683
684
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
685
686
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
687
688
689
            if empty_index >= last_req_index:
                break

690
691
692
            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
693
694
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
695
            assert req_id is not None
696
697
698
699
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
700
701
            self.req_id_to_index[req_id] = empty_index

702
            num_tokens = self._get_active_token_count(last_req_index)
703
704
705
706
707
708

            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()
709

710
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
711
712
                last_req_index, :num_tokens
            ]
713
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
714
715
                last_req_index, :num_tokens
            ]
716
            if last_req_index in self.req_prompt_embeds:
717
718
719
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
720
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
721
722
723
724
725
726
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
727
            self.block_table.move_row(last_req_index, empty_index)
728
729

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
730
731
                last_req_index
            ]
732
733
734

            if self.is_pooling_model:
                last_req_index -= 1
co63oc's avatar
co63oc committed
735
                # Sampling state not used by pooling models.
736
737
738
739
740
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
741
742
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )
743

744
            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
745
746
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
747
748
749
750
751
752
753
754
755
756
757
758
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
759
760
761
762
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

763
            # TODO convert these to LogitsProcessors
764
            if self.allowed_token_ids_mask_cpu_tensor is not None:
765
766
767
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
768

769
            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
770
771
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
772

773
774
775
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

776
        # Trim lists to the batch size.
777
778
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
779
        del self.spec_token_ids[num_reqs:]
780

781
    def refresh_metadata(self):
        """Apply any pending batch updates to the sampling metadata.

        Pooling models only reset the batch-update tracking. All other models
        additionally pass the generated batch update to every logits
        processor. In both cases ``self.sampling_metadata`` is rebuilt only
        when the builder reports that the batch state changed.
        """
        if self.is_pooling_model:
            # No logits processors are updated on this path; only the
            # "batch changed" result of the reset matters.
            if self.batch_update_builder.reset():
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        # Every logits processor is notified, even when the update is falsy.
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        # Update sampling metadata if batch state is changed.
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()
798
799
800

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build the ``SamplingMetadata`` for the requests currently batched.

        Per-request sampling values are copied with ``copy_slice`` from the
        ``*_cpu_tensor`` buffers into their counterpart tensors on ``self``
        for the first ``num_reqs`` rows. A group of values is copied only
        when at least one request in the batch needs it.

        Returns:
            A new ``SamplingMetadata`` describing the first ``num_reqs``
            requests of the batch.
        """
        num_reqs = self.num_reqs

        # Temperature is irrelevant when every request samples greedily.
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs
            )
        else:
            temperature = None

        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(
                self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
            )
            copy_slice(
                self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
            )
            copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
            )

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not self.no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        # NOTE: the penalty tensors are always passed as slices, even when
        # they were not refreshed above; ``no_penalties`` is passed alongside.
        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

883
884
885
886
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of every request, ordered like ``req_ids``."""
        req_ids = self.req_ids
        params_by_req_id = self.pooling_params
        # Every batched request is expected to have exactly one entry.
        assert len(req_ids) == len(params_by_req_id)
        return [params_by_req_id[req_id] for req_id in req_ids]

887
888
889
890
    def get_pooling_states(self) -> list[PoolingStates]:
        """Return the pooling state of every request, ordered like ``req_ids``."""
        req_ids = self.req_ids
        states_by_req_id = self.pooling_states
        # Every batched request is expected to have exactly one entry.
        assert len(req_ids) == len(states_by_req_id)
        return [states_by_req_id[req_id] for req_id in req_ids]

891
892
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Build the ``PoolingMetadata`` for the requests currently batched.

        Returns:
            A ``PoolingMetadata`` holding a cloned copy of the prompt lengths
            of the first ``num_reqs`` requests, the prompt token ids taken
            from the current sampling metadata, and the per-request pooling
            params and pooling states in ``req_ids`` order.
        """
        pooling_params = self.get_pooling_params()
        pooling_states = self.get_pooling_states()

        return PoolingMetadata(
            # Cloned so the metadata does not alias the batch's own buffer.
            prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
            pooling_states=pooling_states,
        )

902
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
903
904
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
905
906
907
908
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
909
910
            pin_memory=self.pin_memory,
        )
911
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
912
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
913
914
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
915
        for i in range(num_reqs):
916
917
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
918

919
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: Per-request repeat counts used to expand the
                request-level LoRA mapping into the token-level mapping.
            num_sampled_tokens: Per-request repeat counts used to expand the
                request-level LoRA mapping into the sampled-token mapping.

        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where, prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        # Only the first num_reqs entries describe live requests.
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        # All LoRA requests currently registered with the batch.
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

944
945
946
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
947
        async_copy_ready_event: torch.Event,
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        Placeholder entries (``-1``) at the tail of each request's output
        token list are replaced in place by the ids sampled for that request
        in the prior step. Requests absent from ``prev_req_id_to_index`` and
        requests whose last output id is not a placeholder are left untouched.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        # Converted lazily so the copy-ready event is only waited on when at
        # least one request actually needs repairing.
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.tolist()
            # Replace placeholder token id(s) with actual sampled id(s).
            new_ids: list[int] = sampled_token_ids[prev_index]
            if not new_ids:
                continue
            # A trailing -1 marks unused slots; only ids before the first -1
            # count as sampled.
            num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
            # Also account for case where there may be a smaller number of
            # output placeholders (tokens can be discarded after kv-load
            # failure) or a larger number (async spec decode adds optimistic
            # placeholders that may exceed the actual acceptance count).
            first_placeholder = req_output_token_ids.index(-1)
            num_placeholders = len(req_output_token_ids) - first_placeholder
            num_to_replace = min(num_sampled_ids, num_placeholders)
            # Slice assignment implicitly resizes the list to
            # (first_placeholder + num_to_replace).
            req_output_token_ids[first_placeholder:] = new_ids[:num_to_replace]
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        In async scheduling case, update spec_token_ids in sampling metadata with
        real draft token ids from prior step. This is called right before they are
        needed by the rejection sampler for penalty/bad_words computation.

        Args:
            draft_token_ids: Draft token ids from the prior step, indexed by
                each request's index in that step (resolved through
                ``self.prev_req_id_to_index``). The lists are only read and
                are never modified.
        """
        if not draft_token_ids or not self.prev_req_id_to_index:
            return

        spec_token_ids = self.sampling_metadata.spec_token_ids
        if spec_token_ids is None:
            return

        prev_req_id_to_index = self.prev_req_id_to_index
        for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
            if not spec_ids:
                # This request has no spec token entries to replace.
                continue
            prev_index = prev_req_id_to_index.get(req_id)
            if prev_index is None:
                # Request was not part of the prior step.
                continue
            draft_ids = draft_token_ids[prev_index]
            if draft_ids:
                # Overwrite in place (the list object is shared with the
                # sampling metadata), keeping at most as many ids as there
                # were entries. Slicing leaves the caller's list intact.
                spec_ids[:] = draft_ids[: len(spec_ids)]
1022

1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
    @property
    def num_reqs(self) -> int:
        """Number of requests currently tracked in ``req_id_to_index``."""
        index_by_req_id = self.req_id_to_index
        return len(index_by_req_id)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty."""
        num_random = len(self.random_reqs)
        return num_random == 0

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty."""
        num_greedy = len(self.greedy_reqs)
        return num_greedy == 0

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` is empty."""
        num_top_p = len(self.top_p_reqs)
        return num_top_p == 0

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` is empty."""
        num_top_k = len(self.top_k_reqs)
        return num_top_k == 0

1043
1044
    @property
    def no_penalties(self) -> bool:
        """True when no request uses a presence, frequency or repetition penalty.

        All three of ``presence_penalties_reqs``, ``frequency_penalties_reqs``
        and ``repetition_penalties_reqs`` must be empty.
        """
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )
1050

1051
    @property
1052
    def max_num_logprobs(self) -> int | None:
1053
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1054

1055
1056
1057
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` is empty."""
        num_restricted = len(self.has_allowed_token_ids)
        return num_restricted == 0