gpu_input_batch.py 42.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.utils import copy_slice
26
from vllm.v1.worker.block_table import MultiGroupBlockTable
27
28
29
30
31


@dataclass
class CachedRequestState:
    """Per-request state that is cached between steps.

    Besides the declared fields, ``__post_init__`` sets the plain attribute
    ``num_prompt_tokens`` (derived from ``prompt_token_ids`` and
    ``prompt_embeds``); it is not a dataclass field.
    """

    req_id: str
    # None when the prompt was supplied through ``prompt_embeds`` instead.
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        """Derive the prompt length and create pooling state when needed."""
        prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.num_prompt_tokens = prompt_len

        # A request with pooling params always starts from a fresh
        # PoolingStates object, replacing whatever was passed in.
        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Total number of tokens: prompt tokens plus output tokens so far."""
        return len(self.output_token_ids) + self.num_prompt_tokens

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position ``idx`` of prompt + output.

        Args:
            idx: position in the concatenated prompt/output sequence.

        Returns:
            The token id, or ``-1`` when ``idx`` lies past the last output
            token.

        Raises:
            ValueError: if ``idx`` falls inside the prompt but the prompt was
                provided via ``prompt_embeds`` (so no token ids exist).
        """
        output_offset = idx - self.num_prompt_tokens
        if output_offset < 0:
            # Position belongs to the prompt.
            if self.prompt_token_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown."
                )
            return self.prompt_token_ids[idx]
        if output_offset < len(self.output_token_ids):
            return self.output_token_ids[output_offset]
        return -1
79
80
81
82


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-request buffers of the batch.

        Args:
            max_num_reqs: number of request rows allocated in every
                per-request buffer.
            max_model_len: number of token columns in the CPU token buffers.
            max_num_batched_tokens: stored and forwarded to the block table.
            device: device on which the sampling tensors are allocated.
            pin_memory: whether the CPU staging tensors are pinned.
            vocab_size: vocabulary size; used for top-k clamping and the
                allowed-token mask width.
            block_sizes: the block size of each kv cache group.
            kernel_block_sizes: forwarded to the block table.
            logitsprocs: logits processors to use; an empty
                ``LogitsProcessors`` is created when none are provided.
            logitsprocs_need_output_token_ids: stored as-is.
            is_spec_decode: stored as-is.
            is_pooling_model: stored as-is; pooling models skip the
                sampling-related bookkeeping in the mutating methods.
            num_speculative_tokens: forwarded to the block table.
            cp_kv_cache_interleave_size: forwarded to the block table.
        """
        per_req_shape = (max_num_reqs,)
        token_buf_shape = (max_num_reqs, max_model_len)

        def cpu_staging(dtype: torch.dtype) -> torch.Tensor:
            # Uninitialized per-request CPU tensor honoring ``pin_memory``.
            return torch.empty(
                per_req_shape, dtype=dtype, device="cpu", pin_memory=pin_memory
            )

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            token_buf_shape, dtype=torch.int32, device="cpu", pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            token_buf_shape, dtype=bool, device="cpu", pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            per_req_shape, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related: each parameter has a device tensor, a CPU staging
        # tensor, a numpy view of the staging tensor, and (where applicable)
        # the set of request ids that actually use the parameter.
        self.temperature = torch.empty(
            per_req_shape, dtype=torch.float32, device=device
        )
        self.temperature_cpu_tensor = cpu_staging(torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty(per_req_shape, dtype=torch.float32, device=device)
        self.top_p_cpu_tensor = cpu_staging(torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty(per_req_shape, dtype=torch.int32, device=device)
        self.top_k_cpu_tensor = cpu_staging(torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty(
            per_req_shape, dtype=torch.float, device=device
        )
        self.frequency_penalties_cpu_tensor = cpu_staging(torch.float)
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty(
            per_req_shape, dtype=torch.float, device=device
        )
        self.presence_penalties_cpu_tensor = cpu_staging(torch.float)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty(
            per_req_shape, dtype=torch.float, device=device
        )
        self.repetition_penalties_cpu_tensor = cpu_staging(torch.float)
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            per_req_shape, dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None
271

272
    @property
273
    def req_ids(self) -> list[str]:
274
275
        # None elements should only be present transiently
        # while performing state updates to the batch.
276
        return cast(list[str], self._req_ids)
277

278
    def _register_add_request(self, request: "CachedRequestState") -> int:
279
280
281
282
283
284
285
286
287
288
        """Track add-request operations for logits processors.
        Not applicable to pooling models.
        """

        # Fill the next empty index if there is one.
        if (new_req_index := self.batch_update_builder.pop_removed()) is None:
            # Append to end otherwise.
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
289
290
291
292
293
        self.batch_update_builder.batch_changed = True
        if request.sampling_params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            self.batch_update_builder.added.append(
294
295
296
297
298
299
300
                (
                    new_req_index,
                    request.sampling_params,
                    request.prompt_token_ids,
                    request.output_token_ids,
                )
            )
301

302
        return new_req_index
303

304
305
306
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the batch and populate its row of state.

        The row index is chosen by ``_register_add_request`` (a previously
        removed index if available, else the end of the batch). Token ids,
        block table row, sampling or pooling parameters and the LoRA mapping
        for that row are then filled in from ``request``.

        Args:
            request: the request to add. It must carry either sampling
                params or pooling params.

        Returns:
            The batch index assigned to the request.

        Raises:
            NotImplementedError: if the request has neither sampling params
                nor pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        # spec_token_ids is preallocated with max_num_reqs entries in __init__
        # (and may be emptied by condense()), so its length is not tied to
        # len(self._req_ids). Reuse the existing slot whenever there is one;
        # blindly appending would grow the list past max_num_reqs with an
        # entry that does not correspond to req_index.
        if req_index < len(self.spec_token_ids):
            self.spec_token_ids[req_index].clear()
        else:
            self.spec_token_ids.append([])

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # Do not let embeddings left behind by a previous occupant of this
            # row be attributed to a request that has none.
            self.req_prompt_embeds.pop(req_index, None)
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top_k is stored as vocab_size and the request
                # is not tracked in top_k_reqs.
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # logprobs == -1 is stored as the full vocabulary size.
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                # Mask everything for this row, then unmask the allowed ids.
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        req_id = request.req_id
        req_index = self.req_id_to_index[req_id]
        cur_spec_token_ids = self.spec_token_ids[req_index]
        # When speculative decoding is used with structured output,
        # the scheduler can drop draft tokens that do not
        # conform to the schema. This can result in
        # scheduler_output.scheduled_spec_decode_tokens being empty,
        # even when speculative decoding is enabled.
        cur_spec_token_ids.clear()
        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
        num_spec_tokens = len(spec_token_ids)
        request.prev_num_draft_len = num_spec_tokens
        if not spec_token_ids:
            return

        # For async scheduling, token_ids_cpu assigned from
        # spec_token_ids are placeholders and will be overwritten in
        # _prepare_input_ids.
        start_index = self.num_tokens_no_spec[req_index]
        end_token_index = start_index + num_spec_tokens
        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
        cur_spec_token_ids.extend(spec_token_ids)

469
    def remove_request(self, req_id: str) -> int | None:
470
        """This method must always be followed by a call to condense().
471

472
473
474
475
476
477
        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
478

479
480
481
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
482
483

        self.batch_update_builder.removed_append(req_index)
484
485
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
486
        self.spec_token_ids[req_index].clear()
487

488
489
490
491
492
493
494
495
496
497
498
499
        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
500
            self.pooling_states.pop(req_id, None)
501
502
            return req_index

503
504
505
506
        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
507
508
509
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
510
511
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
512
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
513
514
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)
515

516
517
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
518
            # False means we don't fill with -inf.
519
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
520
        self.bad_words_token_ids.pop(req_index, None)
521
522
        return req_index

523
524
525
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all batch state stored at indices ``i1`` and ``i2``.

        Both indices must currently hold a request. For pooling models only
        the request ids, token buffers, prompt embeddings, block table rows
        and LoRA mapping are swapped; for other models the swap is also
        recorded in the batch update builder and the sampling-related state
        is exchanged.

        Args:
            i1: batch index of the first request.
            i2: batch index of the second request.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        # The following are 1-D numpy arrays: indexing with an int yields a
        # scalar copy, so tuple assignment swaps them correctly.
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        # Fancy indexing on the right-hand side produces a copy, so this
        # row swap is safe.
        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # Integer-indexing a 2-D torch tensor returns a *view* of the row,
            # so a tuple swap would overwrite row i1 first and then copy the
            # already-overwritten data back into row i2 (both rows would end
            # up equal). Take an explicit copy of one row instead, exactly as
            # done for token_ids_cpu above.
            mask = self.allowed_token_ids_mask_cpu_tensor
            saved_row_i1 = mask[i1].clone()
            mask[i1] = mask[i2]
            mask[i2] = saved_row_i1
625

626
627
628
629
630
631
632
633
634
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        Returns:
          swaps: list of (from,to) swap tuples for moved requests
          empty_req_indices: indices not filled by condensation
635
        """
636
637
        num_reqs = self.num_reqs

638
639
640
641
        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
642
        if num_reqs == 0:
643
            # The batched states are empty.
644
645
            self._req_ids.clear()
            self.req_output_token_ids.clear()
646
            self.spec_token_ids.clear()
647
648
649
650
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
651
        last_req_index = num_reqs + len(empty_req_indices) - 1
652
653
654
655
656
657
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
658
659
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
660
661
662
            if empty_index >= last_req_index:
                break

663
664
665
            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
666
667
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
668
            assert req_id is not None
669
670
671
672
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
673
674
            self.req_id_to_index[req_id] = empty_index

675
676
677
678
679
680
681
682
683
            num_tokens = self.num_tokens_no_spec[last_req_index] + len(
                self.spec_token_ids[last_req_index]
            )

            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()
684

685
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
686
687
                last_req_index, :num_tokens
            ]
688
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
689
690
                last_req_index, :num_tokens
            ]
691
            if last_req_index in self.req_prompt_embeds:
692
693
694
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
695
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
696
697
698
699
700
701
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
702
            self.block_table.move_row(last_req_index, empty_index)
703
704

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
705
706
                last_req_index
            ]
707
708
709

            if self.is_pooling_model:
                last_req_index -= 1
co63oc's avatar
co63oc committed
710
                # Sampling state not used by pooling models.
711
712
713
714
715
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
716
717
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )
718

719
            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
720
721
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
722
723
724
725
726
727
728
729
730
731
732
733
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
734
735
736
737
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

738
            # TODO convert these to LogitsProcessors
739
            if self.allowed_token_ids_mask_cpu_tensor is not None:
740
741
742
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
743

744
            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
745
746
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
747

748
749
750
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

751
        # Trim lists to the batch size.
752
753
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
754
        del self.spec_token_ids[num_reqs:]
755

756
    def refresh_metadata(self):
        """Apply pending batch updates and rebuild sampling metadata if needed.

        For non-pooling models the pending update is fetched from the batch
        update builder (``get_and_reset``) and passed to ``update_state`` of
        every logits processor. For pooling models only the builder's
        ``reset()`` is called. In both cases ``self.sampling_metadata`` is
        regenerated only when the builder reports a truthy result.
        """
        if not self.is_pooling_model:
            # Non-pooling path: every logits processor sees the update, even
            # when it is falsy; the metadata is rebuilt only when it is truthy.
            pending_update = self.batch_update_builder.get_and_reset(self.num_reqs)
            for processor in self.logitsprocs.all:
                processor.update_state(pending_update)
            if pending_update:
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # Pooling path: logits processors are not consulted here; only the
        # batch update tracking is reset.
        if self.batch_update_builder.reset():
            self.sampling_metadata = self._make_sampling_metadata()
773
774
775

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Assemble a ``SamplingMetadata`` for the first ``num_reqs`` batch rows.

        Each group of sampling tensors is passed through ``copy_slice`` only
        when at least one request in the batch needs it. Prompt and output
        token ids are likewise only attached when penalties, bad words or
        logits processors require them.
        """
        batch_size = self.num_reqs

        # NOTE(review): copy_slice(src, dst, n) presumably copies the first n
        # entries of the CPU tensor into its counterpart and returns the
        # sliced destination - confirm against vllm.v1.utils.
        temperature = (
            None
            if self.all_greedy
            else copy_slice(self.temperature_cpu_tensor, self.temperature, batch_size)
        )

        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, batch_size)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, batch_size)

        penalties_active = not self.no_penalties
        if penalties_active:
            # Syncing these tensors is expensive, so they are copied only
            # when some request actually needs penalties during sampling.
            for cpu_src, dst in (
                (self.frequency_penalties_cpu_tensor, self.frequency_penalties),
                (self.presence_penalties_cpu_tensor, self.presence_penalties),
                (self.repetition_penalties_cpu_tensor, self.repetition_penalties),
            ):
                copy_slice(cpu_src, dst, batch_size)

        # The prompt tokens are used only for applying penalties or step
        # pooling during the sampling/pooling process, so the tensor is built
        # only when some request needs penalties/step_pooler to be applied.
        prompt_token_ids = None
        if (
            penalties_active
            or self.logits_processing_needs_token_ids[:batch_size].any()
        ):
            prompt_token_ids = self._make_prompt_token_ids_tensor()

        # Output token ids are attached only if the current requests'
        # sampling parameters require them.
        if (
            penalties_active
            or self.bad_words_token_ids
            or self.logitsprocs_need_output_token_ids
        ):
            output_token_ids = cast(list[list[int]], self.req_output_token_ids)
        else:
            output_token_ids = []

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                batch_size,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:batch_size]

        top_p = None if self.no_top_p else self.top_p[:batch_size]
        top_k = None if self.no_top_k else self.top_k[:batch_size]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=top_p,
            top_k=top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:batch_size],
            presence_penalties=self.presence_penalties[:batch_size],
            repetition_penalties=self.repetition_penalties[:batch_size],
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=not penalties_active,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

858
859
860
861
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of each batched request.

        The result follows the order of ``self.req_ids``. Asserts that
        ``self.pooling_params`` has exactly one entry per request id.
        """
        req_ids = self.req_ids
        params_by_req = self.pooling_params
        assert len(req_ids) == len(params_by_req)
        return [params_by_req[rid] for rid in req_ids]

862
863
864
865
    def get_pooling_states(self) -> list[PoolingStates]:
        """Return the pooling state of each batched request.

        The result follows the order of ``self.req_ids``. Asserts that
        ``self.pooling_states`` has exactly one entry per request id.
        """
        req_ids = self.req_ids
        states_by_req = self.pooling_states
        assert len(req_ids) == len(states_by_req)
        return [states_by_req[rid] for rid in req_ids]

866
867
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Bundle the per-request pooling inputs into a ``PoolingMetadata``.

        Prompt lengths come from the first ``num_reqs`` entries of
        ``num_prompt_tokens``; prompt token ids are reused from the current
        ``sampling_metadata``.
        """
        params = self.get_pooling_params()
        states = self.get_pooling_states()
        batch_prompt_lens = self.num_prompt_tokens[: self.num_reqs]

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(batch_prompt_lens),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=params,
            pooling_states=states,
        )

877
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Build a padded ``(num_reqs, max_prompt_len)`` prompt token id tensor.

        Rows are copied from ``token_ids_cpu`` into an int64 CPU staging
        tensor. Positions past a request's own prompt length are filled with
        ``vocab_size``. The staging tensor is then moved to ``self.device``
        with ``non_blocking=True``.
        """
        batch_size = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:batch_size]
        longest_prompt = prompt_lens.max()

        staging = torch.empty(
            (batch_size, longest_prompt),
            dtype=torch.int64,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        # The numpy view shares storage with the tensor, so writes through
        # it populate the staging tensor directly.
        staging_np = staging.numpy()
        staging_np[:] = self.token_ids_cpu[:batch_size, :longest_prompt]

        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad_id = self.vocab_size
        for row, prompt_len in zip(staging_np, prompt_lens):
            row[prompt_len:] = pad_id
        return staging.to(device=self.device, non_blocking=True)
893

894
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Build the datastructures used to activate the current LoRAs from the
        per-request token counts of the batch.

        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where token_lora_mapping[i] is the LoRA id to use for the ith
               token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        # One LoRA id per request; each id is repeated once per token that
        # the request contributes to the respective mapping.
        per_req_lora_ids = self.request_lora_mapping[: self.num_reqs]
        sampled_mapping = tuple(per_req_lora_ids.repeat(num_sampled_tokens))
        scheduled_mapping = tuple(per_req_lora_ids.repeat(num_scheduled_tokens))

        loras_in_use: set[LoRARequest] = set(self.lora_id_to_lora_request.values())
        return sampled_mapping, scheduled_mapping, loras_in_use

919
920
921
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
922
        async_copy_ready_event: torch.Event,
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In the async scheduling case, patch output_token_ids in the sampling
        metadata with the token ids sampled in the prior step, once their copy
        to CPU has finished. Called right before the logits processors need
        them.

        Placeholder entries (-1) at the tail of a request's output list are
        overwritten in place with that request's previously sampled ids.
        """
        batch_output_ids = self.sampling_metadata.output_token_ids
        if not batch_output_ids or self.sampled_token_ids_cpu is None:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        # Materialized lazily so the copy event is only waited on when at
        # least one request actually has placeholders to fill.
        prev_rows = None
        for cur_index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue

            pending = batch_output_ids[cur_index]
            if not pending or pending[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue

            if prev_rows is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                prev_rows = self.sampled_token_ids_cpu.tolist()

            sampled: list[int] = prev_rows[prev_index]
            if not sampled:
                continue

            # A row ending in -1 is padded; only ids before the first -1 count.
            if sampled[-1] == -1:
                num_sampled = sampled.index(-1)
            else:
                num_sampled = len(sampled)

            # There may be fewer output placeholders than sampled ids
            # (tokens can be discarded after a kv-load failure).
            start = pending.index(-1)
            count = min(num_sampled, len(pending) - start)
            pending[start : start + count] = sampled[:count]

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        In the async scheduling case, update spec_token_ids in the sampling
        metadata with the real draft token ids from the prior step. Called
        right before they are needed by the rejection sampler for
        penalty/bad_words computation.

        Args:
            draft_token_ids: Draft token ids of the prior step, indexed by the
                request's index in ``prev_req_id_to_index``. This argument is
                only read; it is never modified.

        Each non-empty entry of ``spec_token_ids`` is overwritten in place
        (list identity preserved) with at most ``len(entry)`` leading draft
        ids of the matching request. Requests without a previous index or
        with empty drafts are left untouched.
        """
        prev_req_id_to_index = self.prev_req_id_to_index
        if not draft_token_ids or not prev_req_id_to_index:
            return

        spec_token_ids = self.sampling_metadata.spec_token_ids
        if spec_token_ids is None:
            return

        for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
            if not spec_ids:
                continue
            prev_index = prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            draft_ids = draft_token_ids[prev_index]
            if draft_ids:
                # Slice-assign a truncated copy rather than truncating
                # draft_ids itself, so the caller's lists are not mutated.
                spec_ids[:] = draft_ids[: len(spec_ids)]
995

996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
    @property
    def num_reqs(self) -> int:
        """Number of entries currently held in ``req_id_to_index``."""
        tracked = self.req_id_to_index
        return len(tracked)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty."""
        return not len(self.random_reqs)

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty."""
        return not len(self.greedy_reqs)

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` is empty."""
        return not len(self.top_p_reqs)

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` is empty."""
        return not len(self.top_k_reqs)

1016
1017
    @property
    def no_penalties(self) -> bool:
        """True when the presence, frequency and repetition penalty request
        collections are all empty."""
        has_any_penalty = (
            len(self.presence_penalties_reqs) > 0
            or len(self.frequency_penalties_reqs) > 0
            or len(self.repetition_penalties_reqs) > 0
        )
        return not has_any_penalty
1023

1024
    @property
1025
    def max_num_logprobs(self) -> int | None:
1026
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1027

1028
1029
1030
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` is empty."""
        return not len(self.has_allowed_token_ids)