gpu_input_batch.py 43.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.utils import copy_slice
26
from vllm.v1.worker.block_table import MultiGroupBlockTable
27
28
29
30
31


@dataclass
class CachedRequestState:
    """Per-request record holding prompt, output and sampling/pooling state.

    ``num_prompt_tokens`` is not a constructor field; it is derived in
    ``__post_init__`` from ``prompt_token_ids`` / ``prompt_embeds``.
    """

    req_id: str
    # None when the prompt was supplied via `prompt_embeds` instead.
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        """Derive the prompt length and set up pooling state if needed."""
        prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.num_prompt_tokens = prompt_len

        if self.pooling_params is None:
            return
        # A request with pooling params always starts from a fresh
        # PoolingStates, replacing whatever was passed to the constructor.
        self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Prompt length plus the number of output tokens so far."""
        generated = len(self.output_token_ids)
        return self.num_prompt_tokens + generated

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position ``idx`` of prompt + output.

        Returns:
            The prompt or output token id, or ``-1`` when ``idx`` lies past
            the last output token.

        Raises:
            ValueError: if ``idx`` falls inside the prompt but the prompt was
                given as embeddings (``prompt_token_ids`` is ``None``).
        """
        output_offset = idx - self.num_prompt_tokens
        if output_offset < 0:
            # Position belongs to the prompt.
            prompt_ids = self.prompt_token_ids
            if prompt_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown."
                )
            return prompt_ids[idx]
        if output_offset < len(self.output_token_ids):
            return self.output_token_ids[output_offset]
        return -1
79
80
81
82


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-request buffers of the batch.

        Args:
            max_num_reqs: number of request slots in every per-request buffer.
            max_model_len: row width of the token-id buffers; also forwarded
                to the block table.
            max_num_batched_tokens: stored and forwarded to the block table.
            device: device of the sampling-parameter tensors and block table.
            pin_memory: whether the per-request CPU tensors are pinned.
            vocab_size: stored; used elsewhere for top-k / logprobs / masks.
            block_sizes: the block_size of each kv cache group.
            kernel_block_sizes: forwarded to ``MultiGroupBlockTable``.
            max_num_blocks_per_req: forwarded to ``MultiGroupBlockTable`` as
                ``max_num_blocks``.
            logitsprocs: logits processors to use; an empty
                ``LogitsProcessors`` is created when falsy.
            logitsprocs_need_output_token_ids: stored as-is.
            is_spec_decode: stored as-is.
            is_pooling_model: stored as-is.
            cp_kv_cache_interleave_size: forwarded to ``MultiGroupBlockTable``.
        """

        def _per_req_host(alloc, dtype):
            # 1-D CPU tensor with one entry per request slot.
            return alloc(
                (max_num_reqs,), dtype=dtype, device="cpu", pin_memory=pin_memory
            )

        def _per_req_device(dtype):
            # 1-D uninitialized tensor on `device`, one entry per request slot.
            return torch.empty((max_num_reqs,), dtype=dtype, device=device)

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        token_grid_shape = (max_num_reqs, max_model_len)
        self.token_ids_cpu_tensor = torch.zeros(
            token_grid_shape, device="cpu", dtype=torch.int32, pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            token_grid_shape, device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()

        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}

        # Per-request token counters (CPU tensor + numpy view of the same
        # memory).
        self.num_tokens_no_spec_cpu_tensor = _per_req_host(torch.zeros, torch.int32)
        self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
        self.num_prompt_tokens_cpu_tensor = _per_req_host(torch.zeros, torch.int32)
        self.num_prompt_tokens = self.num_prompt_tokens_cpu_tensor.numpy()
        self.num_computed_tokens_cpu_tensor = _per_req_host(torch.zeros, torch.int32)
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related. Each parameter has a device tensor, a CPU tensor
        # and a numpy view of the CPU tensor; all start uninitialized.
        self.temperature = _per_req_device(torch.float32)
        self.temperature_cpu_tensor = _per_req_host(torch.empty, torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = _per_req_device(torch.float32)
        self.top_p_cpu_tensor = _per_req_host(torch.empty, torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = _per_req_device(torch.int32)
        self.top_k_cpu_tensor = _per_req_host(torch.empty, torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = _per_req_device(torch.float)
        self.frequency_penalties_cpu_tensor = _per_req_host(torch.empty, torch.float)
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = _per_req_device(torch.float)
        self.presence_penalties_cpu_tensor = _per_req_host(torch.empty, torch.float)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = _per_req_device(torch.float)
        self.repetition_penalties_cpu_tensor = _per_req_host(
            torch.empty, torch.float
        )
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding (every slot starts at 1 accepted token).
        self.num_accepted_tokens_cpu_tensor = _per_req_host(torch.ones, torch.int32)
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None
283

284
    @property
285
    def req_ids(self) -> list[str]:
286
287
        # None elements should only be present transiently
        # while performing state updates to the batch.
288
        return cast(list[str], self._req_ids)
289

290
    def _register_add_request(self, request: "CachedRequestState") -> int:
291
292
293
294
295
296
297
298
299
300
        """Track add-request operations for logits processors.
        Not applicable to pooling models.
        """

        # Fill the next empty index if there is one.
        if (new_req_index := self.batch_update_builder.pop_removed()) is None:
            # Append to end otherwise.
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
301
302
303
304
305
        self.batch_update_builder.batch_changed = True
        if request.sampling_params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            self.batch_update_builder.added.append(
306
307
308
309
310
311
312
                (
                    new_req_index,
                    request.sampling_params,
                    request.prompt_token_ids,
                    request.output_token_ids,
                )
            )
313

314
        return new_req_index
315

316
317
318
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the batch and populate its per-slot state.

        The slot is chosen by ``_register_add_request`` (a reused removed
        index, or the next index at the end). Token ids, counters, the block
        table row, sampling or pooling parameters and the LoRA mapping for
        that slot are then filled from ``request``.

        Args:
            request: the request to add. It must carry either sampling
                params or pooling params.

        Returns:
            The batch index assigned to the request.

        Raises:
            NotImplementedError: if the request has neither sampling params
                nor pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        # `spec_token_ids` is pre-sized in __init__ but may also have been
        # emptied by condense(), so make sure a slot exists at `req_index`
        # and reset exactly that slot. Blindly appending would add a list at
        # an index that is never the one read back for this request.
        while len(self.spec_token_ids) <= req_index:
            self.spec_token_ids.append([])
        self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # Do not let this request inherit embeddings left behind by a
            # previous occupant of the slot.
            self.req_prompt_embeds.pop(req_index, None)
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top_k is stored as vocab_size and the request
                # is not tracked in top_k_reqs.
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # logprobs == -1 is stored as vocab_size.
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                # Mask everything for this row, then unmask the allowed ids.
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        req_id = request.req_id
        req_index = self.req_id_to_index[req_id]
        cur_spec_token_ids = self.spec_token_ids[req_index]
        # When speculative decoding is used with structured output,
        # the scheduler can drop draft tokens that do not
        # conform to the schema. This can result in
        # scheduler_output.scheduled_spec_decode_tokens being empty,
        # even when speculative decoding is enabled.
        cur_spec_token_ids.clear()
        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
        num_spec_tokens = len(spec_token_ids)
        request.prev_num_draft_len = num_spec_tokens
        if not spec_token_ids:
            return

        # For async scheduling, token_ids_cpu assigned from
        # spec_token_ids are placeholders and will be overwritten in
        # _prepare_input_ids.
        start_index = self.num_tokens_no_spec[req_index]
        end_token_index = start_index + num_spec_tokens
        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
        cur_spec_token_ids.extend(spec_token_ids)

481
    def remove_request(self, req_id: str) -> int | None:
482
        """This method must always be followed by a call to condense().
483

484
485
486
487
488
489
        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
490

491
492
493
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
494
495

        self.batch_update_builder.removed_append(req_index)
496
497
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
498
        self.spec_token_ids[req_index].clear()
499

500
501
502
503
504
505
506
507
508
509
510
511
        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
512
            self.pooling_states.pop(req_id, None)
513
514
            return req_index

515
516
517
518
        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
519
520
521
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
522
523
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
524
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
525
526
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)
527

528
529
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
530
            # False means we don't fill with -inf.
531
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
532
        self.bad_words_token_ids.pop(req_index, None)
533
534
        return req_index

535
536
537
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request batch state between indices i1 and i2.

        Request ids, token buffers, counters, prompt embeddings, the block
        table row and the LoRA mapping are swapped for every model type. For
        non-pooling models the swap is additionally recorded in the
        batch-update builder and the sampling-related state is swapped too.

        Args:
            i1: first batch index; must hold an active request.
            i2: second batch index; must hold an active request.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        # Check both slots are occupied before any state is mutated.
        assert old_id_i1 is not None and old_id_i2 is not None

        # Only swap the active token prefix for each request. Copying full
        # max_model_len rows is expensive and unnecessary during reordering.
        # These counts must be taken before the counters below are swapped.
        i1_active_token_count = self._get_active_token_count(i1)
        i2_active_token_count = self._get_active_token_count(i2)
        max_active_token_count = max(i1_active_token_count, i2_active_token_count)

        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        tmp_token_ids = self.token_ids_cpu[i1, :max_active_token_count].copy()
        self.token_ids_cpu[i1, :max_active_token_count] = self.token_ids_cpu[
            i2, :max_active_token_count
        ]
        self.token_ids_cpu[i2, :max_active_token_count] = tmp_token_ids

        # Advanced (list) indexing on the right-hand side yields a copy, so
        # this row exchange is safe.
        self.is_token_ids[[i1, i2], :max_active_token_count] = self.is_token_ids[
            [i2, i1], :max_active_token_count
        ]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        # The buffers below are 1-D, so indexing yields scalars (not views)
        # and plain tuple swaps are safe.
        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        allowed_mask = self.allowed_token_ids_mask_cpu_tensor
        if allowed_mask is not None:
            # Rows of a 2-D tensor are views: a tuple swap
            # (`m[i1], m[i2] = m[i2], m[i1]`) would overwrite row i1 first
            # and then copy the already-overwritten row into i2, leaving both
            # rows equal. Index with lists so the right-hand side is a copy.
            allowed_mask[[i1, i2]] = allowed_mask[[i2, i1]]
646

647
648
649
650
651
    def _get_active_token_count(self, req_index: int) -> int:
        return int(self.num_tokens_no_spec[req_index]) + len(
            self.spec_token_ids[req_index]
        )

652
653
654
655
656
657
658
659
660
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        Returns:
          swaps: list of (from,to) swap tuples for moved requests
          empty_req_indices: indices not filled by condensation
661
        """
662
663
        num_reqs = self.num_reqs

664
665
666
667
        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
668
        if num_reqs == 0:
669
            # The batched states are empty.
670
671
            self._req_ids.clear()
            self.req_output_token_ids.clear()
672
            self.spec_token_ids.clear()
673
674
675
676
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
677
        last_req_index = num_reqs + len(empty_req_indices) - 1
678
679
680
681
682
683
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
684
685
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
686
687
688
            if empty_index >= last_req_index:
                break

689
690
691
            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
692
693
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
694
            assert req_id is not None
695
696
697
698
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
699
700
            self.req_id_to_index[req_id] = empty_index

701
            num_tokens = self._get_active_token_count(last_req_index)
702
703
704
705
706
707

            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()
708

709
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
710
711
                last_req_index, :num_tokens
            ]
712
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
713
714
                last_req_index, :num_tokens
            ]
715
            if last_req_index in self.req_prompt_embeds:
716
717
718
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
719
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
720
721
722
723
724
725
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
726
            self.block_table.move_row(last_req_index, empty_index)
727
728

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
729
730
                last_req_index
            ]
731
732
733

            if self.is_pooling_model:
                last_req_index -= 1
co63oc's avatar
co63oc committed
734
                # Sampling state not used by pooling models.
735
736
737
738
739
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
740
741
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )
742

743
            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
744
745
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
746
747
748
749
750
751
752
753
754
755
756
757
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
758
759
760
761
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

762
            # TODO convert these to LogitsProcessors
763
            if self.allowed_token_ids_mask_cpu_tensor is not None:
764
765
766
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
767

768
            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
769
770
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
771

772
773
774
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

775
        # Trim lists to the batch size.
776
777
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
778
        del self.spec_token_ids[num_reqs:]
779

780
    def refresh_metadata(self):
        """Apply any batch updates to sampling metadata.

        Pooling models only reset the batch-update tracker. All other models
        additionally build a batch update and hand it to every logits
        processor. In both cases ``self.sampling_metadata`` is rebuilt only
        when the tracker reports that the batch changed.
        """
        if self.is_pooling_model:
            # Logits processors are not notified on this path; only the
            # change tracking is reset.
            if self.batch_update_builder.reset():
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            # Processors are called even when ``batch_update`` is falsy.
            logit_proc.update_state(batch_update)

        # Update sampling metadata if batch state is changed.
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()
797
798
799

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` covering the first ``num_reqs`` rows.

        Each ``*_cpu_tensor`` buffer is copied into its matching
        ``self.<name>`` tensor with ``copy_slice``, but only for features that
        at least one request in the batch uses (non-greedy sampling, top-p,
        top-k, penalties, allowed-token-id masks).

        Returns:
            A new ``SamplingMetadata``. ``temperature``, ``top_p``, ``top_k``,
            ``prompt_token_ids`` and ``allowed_token_ids_mask`` are ``None``
            when unused; ``output_token_ids`` is an empty list when no request
            needs it.
        """
        num_reqs = self.num_reqs
        no_penalties = self.no_penalties

        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs
            )
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(
                self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
            )
            copy_slice(
                self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
            )
            copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
            )

        needs_prompt_token_ids = (
            not no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            # Penalty tensors are always passed sliced; ``no_penalties`` tells
            # the consumer whether they were refreshed above.
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

882
883
884
885
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of every request, in ``req_ids`` order.

        Asserts that ``self.pooling_params`` holds exactly as many entries as
        there are request ids.
        """
        req_ids = self.req_ids
        params_by_req_id = self.pooling_params
        assert len(req_ids) == len(params_by_req_id)
        return [params_by_req_id[req_id] for req_id in req_ids]

886
887
888
889
    def get_pooling_states(self) -> list[PoolingStates]:
        """Return the pooling state of every request, in ``req_ids`` order.

        Asserts that ``self.pooling_states`` holds exactly as many entries as
        there are request ids.
        """
        req_ids = self.req_ids
        states_by_req_id = self.pooling_states
        assert len(req_ids) == len(states_by_req_id)
        return [states_by_req_id[req_id] for req_id in req_ids]

890
891
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the ``PoolingMetadata`` for the current batch.

        ``prompt_lens`` is a clone of the first ``num_reqs`` entries of
        ``num_prompt_tokens_cpu_tensor``, so later writes to that buffer do
        not change the returned metadata. ``prompt_token_ids`` is taken from
        the current ``self.sampling_metadata``.
        """
        # Gather per-request values first; both helpers assert that their
        # dict has exactly one entry per request id.
        pooling_params = self.get_pooling_params()
        pooling_states = self.get_pooling_states()

        return PoolingMetadata(
            prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
            pooling_states=pooling_states,
        )

901
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
902
903
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
904
905
906
907
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
908
909
            pin_memory=self.pin_memory,
        )
910
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
911
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
912
913
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
914
        for i in range(num_reqs):
915
916
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
917

918
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens and num_sampled_tokens for each request
        in the batch, return datastructures used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: Per-request repeat counts used to expand the
                request-level LoRA ids into token_lora_mapping.
            num_sampled_tokens: Per-request repeat counts used to expand the
                request-level LoRA ids into prompt_lora_mapping.

        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where, prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        # One LoRA id per active request row.
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        # Every request currently registered in lora_id_to_lora_request.
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

943
944
945
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
946
        async_copy_ready_event: torch.Event,
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        Only requests that were present in the previous step (per
        prev_req_id_to_index) and whose output list ends with the -1
        placeholder are touched. The per-request lists are modified in place.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        # Converted lazily so the copy-ready event is only waited on when at
        # least one request actually needs repairing.
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.tolist()
            # Replace placeholder token id(s) with actual sampled id(s).
            new_ids: list[int] = sampled_token_ids[prev_index]
            if not new_ids:
                continue
            # Sampled rows are right-padded with -1; count the real ids.
            num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
            # Also account for case where there may be a smaller number of
            # output placeholders (tokens can be discarded after kv-load
            # failure) or a larger number (async spec decode adds optimistic
            # placeholders that may exceed the actual acceptance count).
            first_placeholder = req_output_token_ids.index(-1)
            num_placeholders = len(req_output_token_ids) - first_placeholder
            num_to_replace = min(num_sampled_ids, num_placeholders)
            # new_ids is a fresh list from tolist(), safe to truncate in place.
            del new_ids[num_to_replace:]
            req_output_token_ids[first_placeholder:] = new_ids
            # ^ Implicitly resizes to (first_placeholder + num_to_replace)
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        In async scheduling case, update spec_token_ids in sampling metadata with
        real draft token ids from prior step. This is called right before they are
        needed by the rejection sampler for penalty/bad_words computation.

        Args:
            draft_token_ids: Draft token ids indexed by the request's position
                in the previous step (looked up via prev_req_id_to_index).
                The lists are only read, never modified.

        A request's spec list is rewritten in place only when it is non-empty,
        the request existed in the previous step, and its draft list is
        non-empty. At most len(spec list) draft ids are copied.
        """
        if not draft_token_ids or not self.prev_req_id_to_index:
            return

        spec_token_ids = self.sampling_metadata.spec_token_ids
        if spec_token_ids is None:
            return

        prev_req_id_to_index = self.prev_req_id_to_index
        for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
            if not spec_ids:
                continue
            prev_index = prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            draft_ids = draft_token_ids[prev_index]
            if not draft_ids:
                continue
            # Slice-assign from a truncated copy: the caller's draft list is
            # left intact, and it stays correct even if draft_ids and
            # spec_ids are the same list object.
            spec_ids[:] = draft_ids[: len(spec_ids)]
1021

1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
    @property
    def num_reqs(self) -> int:
        """Number of entries in ``req_id_to_index``."""
        index_by_req_id = self.req_id_to_index
        return len(index_by_req_id)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty."""
        return not self.random_reqs

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty."""
        return not self.greedy_reqs

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` is empty."""
        return not self.top_p_reqs

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` is empty."""
        return not self.top_k_reqs

1042
1043
    @property
    def no_penalties(self) -> bool:
        """True when no request uses presence, frequency or repetition penalties.

        Checks that ``presence_penalties_reqs``, ``frequency_penalties_reqs``
        and ``repetition_penalties_reqs`` are all empty.
        """
        return not (
            self.presence_penalties_reqs
            or self.frequency_penalties_reqs
            or self.repetition_penalties_reqs
        )
1049

1050
    @property
1051
    def max_num_logprobs(self) -> int | None:
1052
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1053

1054
1055
1056
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` is empty."""
        return not self.has_allowed_token_ids