gpu_input_batch.py 39.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4
5

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
26
from vllm.v1.utils import copy_slice
27
from vllm.v1.worker.block_table import MultiGroupBlockTable
28
29
30
31
32


@dataclass
class CachedRequestState:
    """Per-request state the GPU worker keeps between scheduler steps.

    The prompt is supplied either as token IDs (``prompt_token_ids``) or as
    embeddings (``prompt_embeds``). ``num_prompt_tokens`` is not a dataclass
    field; it is derived in ``__post_init__`` from whichever one is present.
    """

    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    def __post_init__(self):
        # Prompt length works for both token-ID and embedding prompts.
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )

    @property
    def num_tokens(self) -> int:
        """Prompt length plus the number of output tokens produced so far."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        """Return the token ID at position ``idx`` of prompt + output.

        Returns -1 when ``idx`` lies past the last output token.

        Raises:
          ValueError: if ``idx`` falls in a prompt that was supplied as
            embeddings, so no token ID is known for it.
        """
        prompt_len = self.num_prompt_tokens
        if idx >= prompt_len:
            output_pos = idx - prompt_len
            if output_pos < len(self.output_token_ids):
                return self.output_token_ids[output_pos]
            return -1
        if self.prompt_token_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return self.prompt_token_ids[idx]


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent batch for at most ``max_num_reqs`` requests.

        Almost all per-request state lives in preallocated arrays indexed by
        the request's slot (``req_index``); ``req_id_to_index`` maps request
        IDs to slots. Sampling parameters are held as a device tensor, a CPU
        tensor (pinned when ``pin_memory`` is set) and a NumPy view of that
        CPU tensor for cheap host-side writes.
        """
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode

        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        def per_request_buffers(
            dtype: torch.dtype,
        ) -> tuple[torch.Tensor, torch.Tensor, np.ndarray]:
            # One uninitialized value per request slot, returned as
            # (device tensor, CPU tensor, NumPy view of the CPU tensor).
            device_buf = torch.empty((max_num_reqs,), dtype=dtype, device=device)
            host_buf = torch.empty(
                (max_num_reqs,), dtype=dtype, device="cpu", pin_memory=pin_memory
            )
            return device_buf, host_buf, host_buf.numpy()

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        token_grid = (max_num_reqs, max_model_len)
        self.token_ids_cpu_tensor = torch.zeros(
            token_grid, device="cpu", dtype=torch.int32, pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Per-position flag: True where token_ids_cpu holds a real token ID,
        # False where that prompt position was supplied as an embedding.
        self.is_token_ids_tensor = torch.zeros(
            token_grid, device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()

        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens, self.num_tokens_no_spec, self.num_prompt_tokens = (
            np.zeros(max_num_reqs, dtype=np.int32) for _ in range(3)
        )
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related.
        (
            self.temperature,
            self.temperature_cpu_tensor,
            self.temperature_cpu,
        ) = per_request_buffers(torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p, self.top_p_cpu_tensor, self.top_p_cpu = per_request_buffers(
            torch.float32
        )
        self.top_p_reqs: set[str] = set()

        self.top_k, self.top_k_cpu_tensor, self.top_k_cpu = per_request_buffers(
            torch.int32
        )
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        (
            self.frequency_penalties,
            self.frequency_penalties_cpu_tensor,
            self.frequency_penalties_cpu,
        ) = per_request_buffers(torch.float)
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        (
            self.presence_penalties,
            self.presence_penalties_cpu_tensor,
            self.presence_penalties_cpu,
        ) = per_request_buffers(torch.float)
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        (
            self.repetition_penalties,
            self.repetition_penalties_cpu_tensor,
            self.repetition_penalties_cpu,
        ) = per_request_buffers(torch.float)
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related (slot -> lora_int_id, 0 meaning "no LoRA")
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        # Both mask tensors are allocated lazily by add_request().
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change. Must run
        # after every attribute above has been created.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None

        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None

269
    @property
    def req_ids(self) -> list[str]:
        """Request IDs of the batch, ordered by slot index.

        remove_request() writes ``None`` into a vacated slot and condense()
        later fills or trims it, so ``None`` entries can appear only while the
        batch is being updated. The cast narrows the static type only; it
        performs no runtime check and returns the underlying list itself.
        """
        slots = self._req_ids
        return cast(list[str], slots)

275
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Pick a slot for ``request`` and record the addition.

        A slot freed by an earlier remove_request() is reused when the batch
        update builder has one; otherwise the request goes at index
        ``self.num_reqs``. Requests with sampling params also get an
        ``(index, sampling_params, prompt_token_ids, output_token_ids)``
        entry queued on the builder so logits processors can see the add.
        Not applicable to pooling models.

        Returns:
          The slot index assigned to the request.
        """
        builder = self.batch_update_builder

        slot = builder.pop_removed()
        if slot is None:
            # No hole to fill: append after the current requests.
            slot = self.num_reqs
        assert slot < self.max_num_reqs

        builder.batch_changed = True

        params = request.sampling_params
        if params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            entry = (
                slot,
                params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
            builder.added.append(entry)

        return slot

301
302
303
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the persistent batch.

        The slot is chosen by _register_add_request(): a slot freed by an
        earlier removal if there is one, otherwise the end of the batch. All
        slot-keyed state for that slot is rewritten here, including state the
        new request does not use (prompt embeddings, allowed-token mask row),
        so nothing left by a previous occupant of the slot leaks into it.

        Returns:
          The slot index the request was placed at.

        Raises:
          NotImplementedError: if the request carries neither sampling nor
            pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # Embeddings are keyed by slot; drop any left behind by an earlier
            # request in this slot so this request never appears to have them.
            self.req_prompt_embeds.pop(req_index, None)
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if self.is_spec_decode and is_spec_decode_unsupported(sampling_params):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Non-positive or >= vocab_size top_k means "no top-k limit".
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # logprobs == -1 requests logprobs over the whole vocabulary.
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False
            elif self.allowed_token_ids_mask_cpu_tensor is not None:
                # This slot may still hold a mask row from an earlier request;
                # reset it so this unrestricted request is not masked.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

441
    def remove_request(self, req_id: str) -> int | None:
        """Remove ``req_id`` from the batch, leaving a hole at its slot.

        This method must always be followed by a call to condense().

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()

        # Prompt embeddings are keyed by slot, not by request ID. Release them
        # now so the tensor is not kept alive after the request is gone and
        # cannot be attributed to the next request placed in this slot.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this adapter: forget the adapter too.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.spec_decode_unsupported_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

493
494
495
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-slot state of the requests at slots ``i1`` and ``i2``.

        Both slots must hold live requests. For non-pooling models the swap is
        also recorded in ``batch_update_builder.moved`` so logits processors
        can follow it.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        self.num_tokens[i1], self.num_tokens[i2] = (
            self.num_tokens[i2],
            self.num_tokens[i1],
        )
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        # Advanced indexing on the right-hand side yields a copy, so this
        # swap is safe (unlike a tuple swap of row views).
        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        # Per-slot flag written by add_request() for pooling requests; it has
        # to travel with the request like every other per-slot array.
        self.logits_processing_needs_token_ids[[i1, i2]] = (
            self.logits_processing_needs_token_ids[[i2, i1]]
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # A tuple swap of ``mask[i1]`` / ``mask[i2]`` is unsafe: both are
            # views, so row i2 is copied into row i1 and then copied straight
            # back, leaving both rows equal to the old row i2. Advanced
            # indexing materializes the right-hand side first.
            mask = self.allowed_token_ids_mask_cpu_tensor
            mask[[i1, i2]] = mask[[i2, i1]]

600
601
602
603
604
605
606
607
608
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Slots freed by remove_request() (tracked by the batch update builder)
        are filled by repeatedly moving the highest-indexed live request into
        the lowest free slot, until live requests occupy ``[0, num_reqs)``.
        Any consecutive empty indices at the very end of the list are not
        filled; the per-request lists are trimmed to ``num_reqs`` instead.

        All per-slot state moves with its request, and slot-keyed state that
        would otherwise linger in the vacated slot (spec tokens, allowed-token
        mask row, prompt embeddings, generator, bad words) is cleared or
        moved out of it. For non-pooling models each move is recorded in
        ``batch_update_builder.moved`` to support logitsprocs.

        Returns nothing; the batch is updated in place.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # empty_index < last_req_index is guaranteed by the break above:
            # exchange the list objects, then clear the one now at the
            # vacated slot.
            (
                self.spec_token_ids[last_req_index],
                self.spec_token_ids[empty_index],
            ) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            # Move prompt embeddings with the request; if it has none, drop
            # anything still keyed by the destination slot.
            moved_embeds = self.req_prompt_embeds.pop(last_req_index, None)
            if moved_embeds is not None:
                self.req_prompt_embeds[empty_index] = moved_embeds
            else:
                self.req_prompt_embeds.pop(empty_index, None)
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]

            # Per-slot flag written by add_request() for pooling requests; it
            # must follow the request like the other per-slot arrays.
            self.logits_processing_needs_token_ids[empty_index] = (
                self.logits_processing_needs_token_ids[last_req_index]
            )

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
                # Row assignment copies the data and leaves the source row
                # intact; clear it so a request later appended at the vacated
                # slot starts unmasked. False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[last_req_index].fill_(False)

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]

732
    def refresh_metadata(self):
        """Apply any pending batch updates to the sampling metadata.

        Pooling models only reset the batch-update builder and rebuild the
        sampling metadata if the builder reports a change.

        Other models fetch the accumulated batch update, forward it to every
        logits processor's ``update_state`` (even when it is falsy), and
        rebuild the sampling metadata only when the update is truthy.
        """
        if self.is_pooling_model:
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` snapshot for the current batch.

        The ``copy_slice`` calls move the first ``num_reqs`` entries of each
        CPU staging tensor into its counterpart. They are issued only when
        the corresponding feature is in use:

        - temperature, unless ``all_greedy``;
        - top-p, unless ``no_top_p``;
        - top-k, unless ``no_top_k``;
        - the three penalty tensors, unless ``no_penalties``;
        - the allowed-token-ids mask, unless ``no_allowed_token_ids``.

        Returns:
            A new ``SamplingMetadata`` whose per-request tensors are sliced to
            the first ``num_reqs`` rows.
        """
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs
            )
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(
                self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs
            )
            copy_slice(
                self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs
            )
            copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
            )

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not self.no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=output_token_ids,
            spec_token_ids=cast(list[list[int]], self.spec_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of every request, in ``req_ids`` order."""
        ordered_ids = self.req_ids
        params_by_id = self.pooling_params
        assert len(ordered_ids) == len(params_by_id)
        return list(map(params_by_id.__getitem__, ordered_ids))

    def get_pooling_metadata(self) -> PoolingMetadata:
        """Build ``PoolingMetadata`` for the current batch.

        The metadata is assembled from three sources:

        - prompt lengths: the first ``num_reqs`` entries of
          ``num_prompt_tokens``, wrapped with ``torch.from_numpy`` so they
          share memory with that array;
        - prompt token ids: taken from the current ``sampling_metadata``;
        - pooling params: as returned by ``get_pooling_params``.
        """
        pooling_params = self.get_pooling_params()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
848
849
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
850
851
852
853
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
854
855
            pin_memory=self.pin_memory,
        )
856
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
857
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
858
859
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
860
        for i in range(num_reqs):
861
862
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
863

864
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given per-request token counts for the batch, return datastructures
        used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: Number of scheduled tokens for each of the
                first ``num_reqs`` requests.
            num_sampled_tokens: Number of sampled tokens for each of the
                first ``num_reqs`` requests.

        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where, prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of all LoRA requests currently held in
               ``lora_id_to_lora_request``.
        """
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        # ndarray.repeat with a per-element count expands each request's
        # LoRA id once per token of that request.
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
892
        async_copy_ready_event: torch.Event,
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        A request is patched only if it meets both conditions:

        - it was present in the previous step (found in
          ``prev_req_id_to_index``);
        - its output list currently ends with the ``-1`` placeholder.

        The copy-ready event is synchronized lazily, at most once, and only if
        at least one request needs patching.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
            # Replace placeholder token id with actual sampled id.
            req_output_token_ids[-1] = sampled_token_ids[prev_index]

    @property
    def num_reqs(self) -> int:
        """Count of requests currently tracked in ``req_id_to_index``."""
        index_by_id = self.req_id_to_index
        return len(index_by_id)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty."""
        random_count = len(self.random_reqs)
        return random_count == 0

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty."""
        greedy_count = len(self.greedy_reqs)
        return greedy_count == 0

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` is empty."""
        top_p_count = len(self.top_p_reqs)
        return top_p_count == 0

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` is empty."""
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when none of the presence, frequency, or repetition penalty
        request collections has any entries."""
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )

    @property
964
    def max_num_logprobs(self) -> int | None:
965
        return max(self.num_logprobs.values()) if self.num_logprobs else None
966

967
968
969
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` is empty."""
        restricted_count = len(self.has_allowed_token_ids)
        return restricted_count == 0