# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
26
from vllm.v1.utils import copy_slice
27
from vllm.v1.worker.block_table import MultiGroupBlockTable
28
29
30
31
32


@dataclass
class CachedRequestState:
    """Cached state for a single request.

    The prompt may be known as token IDs (``prompt_token_ids``), as
    embeddings (``prompt_embeds``), or both. ``num_prompt_tokens`` is not a
    dataclass field; it is computed once in ``__post_init__`` from those two
    inputs.
    """

    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    def __post_init__(self):
        # Derived attribute: prompt length from token IDs and/or embeddings.
        token_ids, embeds = self.prompt_token_ids, self.prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            token_ids, embeds
        )

    @property
    def num_tokens(self) -> int:
        """Prompt length plus the number of output tokens so far."""
        return len(self.output_token_ids) + self.num_prompt_tokens

    def get_token_id(self, idx: int) -> int:
        """Return the token ID at position ``idx`` of prompt + output.

        Returns:
          The token ID, or -1 if ``idx`` is past the last output token.

        Raises:
          ValueError: if ``idx`` lies in a prompt that was supplied only as
            embeddings, so no token ID exists for it.
        """
        prompt_len = self.num_prompt_tokens
        if idx >= prompt_len:
            out_pos = idx - prompt_len
            if out_pos < len(self.output_token_ids):
                return self.output_token_ids[out_pos]
            return -1
        if self.prompt_token_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return self.prompt_token_ids[idx]
69
70
71
72


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
        dcp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-request state for up to
        ``max_num_reqs`` requests.

        Each per-request array below is indexed by the request's slot
        ("req index"). ``add_request`` assigns a slot, and ``swap_states`` /
        ``condense`` may move a request to another slot later.

        Args:
          max_num_reqs: number of request slots to allocate.
          max_model_len: per-request width of the token-id table.
          max_num_batched_tokens: forwarded to the block table.
          device: device that holds the sampling-parameter tensors.
          pin_memory: whether the CPU staging tensors are pinned.
          vocab_size: vocabulary size (top-k / logprobs defaults and the
            width of the allowed-token mask).
          block_sizes, kernel_block_sizes, num_speculative_tokens,
          dcp_kv_cache_interleave_size: forwarded to MultiGroupBlockTable.
          logitsprocs: logits processors; an empty LogitsProcessors() if None.
          logitsprocs_need_output_token_ids: stored as-is.
          is_spec_decode: enables tracking of spec-decode-unsupported requests.
          is_pooling_model: pooling models skip sampling-state bookkeeping.
        """
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode

        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        per_req = (max_num_reqs,)

        def device_buffer(dtype: torch.dtype) -> torch.Tensor:
            # One uninitialized entry per request slot on `device`.
            return torch.empty(per_req, dtype=dtype, device=device)

        def host_buffer(dtype: torch.dtype) -> torch.Tensor:
            # CPU counterpart of device_buffer(); pinned iff `pin_memory`.
            return torch.empty(
                per_req, dtype=dtype, device="cpu", pin_memory=pin_memory
            )

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        token_grid = (max_num_reqs, max_model_len)
        self.token_ids_cpu_tensor = torch.zeros(
            token_grid, device="cpu", dtype=torch.int32, pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # True where a position's token is known by ID; add_request() sets it
        # False for prompts supplied only as embeddings.
        self.is_token_ids_tensor = torch.zeros(
            token_grid, device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()

        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            per_req, device="cpu", dtype=torch.int32, pin_memory=pin_memory
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
            dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
        )

        # Sampling-related. Each parameter gets a device tensor, a CPU
        # staging tensor plus its numpy view, and set(s) of req_ids.
        self.temperature = device_buffer(torch.float32)
        self.temperature_cpu_tensor = host_buffer(torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = device_buffer(torch.float32)
        self.top_p_cpu_tensor = host_buffer(torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = device_buffer(torch.int32)
        self.top_k_cpu_tensor = host_buffer(torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = device_buffer(torch.float)
        self.frequency_penalties_cpu_tensor = host_buffer(torch.float)
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = device_buffer(torch.float)
        self.presence_penalties_cpu_tensor = host_buffer(torch.float)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = device_buffer(torch.float)
        self.repetition_penalties_cpu_tensor = host_buffer(torch.float)
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            per_req, dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros(per_req, dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        # Both mask tensors are allocated lazily by add_request().
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int] | None] = []

        # This is updated each time the batch constituents change.
        # Must stay after every attribute above has been initialized.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None

        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.cuda.Event | None = None
    @property
    def req_ids(self) -> list[str]:
        """Request IDs in the batch, ordered by req index.

        ``remove_request`` leaves ``None`` holes that ``condense`` removes, so
        ``None`` entries exist only transiently while the batch is being
        updated. The cast expresses the steady-state invariant; it returns
        the same underlying list, not a copy.
        """
        ids = self._req_ids
        return cast(list[str], ids)
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Choose the req index for a new request and log the addition.

        A slot vacated by ``remove_request`` is reused when
        ``batch_update_builder.pop_removed()`` yields one. Otherwise the
        request is appended at index ``num_reqs``.

        The batch is always marked as changed. The detailed ``added`` record
        used by logitsprocs is only produced for requests with sampling
        params, so pooling requests get a slot but no such record.
        """
        builder = self.batch_update_builder
        slot = builder.pop_removed()
        if slot is None:
            # Nothing to reuse: place the request after the current last one.
            slot = self.num_reqs
        assert slot < self.max_num_reqs

        builder.batch_changed = True
        sampling_params = request.sampling_params
        if not sampling_params:
            return slot

        builder.added.append(
            (
                slot,
                sampling_params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
        )
        return slot
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Place ``request`` into a slot and populate its per-slot state.

        The slot is chosen by ``_register_add_request``: a slot freed by
        ``remove_request`` if one exists, otherwise a new one at the end. A
        slot may still carry data from a previous occupant. The prompt-embedding
        entry and the allowed-token mask row are therefore reset when the new
        request does not use them.

        Args:
          request: cached state of the request to add.

        Returns:
          The req index assigned to the request.

        Raises:
          NotImplementedError: if the request has neither sampling nor pooling
            params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index] = []

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # A reused slot may still hold a previous occupant's embeddings.
            # swap_states()/condense() move entries of this dict by index, so
            # a stale entry would otherwise travel with this request.
            self.req_prompt_embeds.pop(req_index, None)
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Total number of tokens (prompt + output) for this request.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if self.is_spec_decode and is_spec_decode_unsupported(sampling_params):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range / disabled top-k means "consider every token".
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.prompt_logprobs == -1
                    else sampling_params.prompt_logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False
            elif self.allowed_token_ids_mask_cpu_tensor is not None:
                # The row may still hold another request's mask (e.g. a row
                # vacated when condense() moved a request away). Reset it to
                # "every token allowed" (False = don't fill with -inf).
                self.allowed_token_ids_mask_cpu_tensor[req_index] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index
    def remove_request(self, req_id: str) -> int | None:
        """Remove a request from the batch, leaving its slot empty.

        This method must always be followed by a call to condense().

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index] = None

        # Release this slot's prompt embeddings (if any). Otherwise the tensor
        # stays referenced after the request is gone. swap_states()/condense()
        # move entries by index, so it could also end up attached to whichever
        # request later occupies this slot.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this adapter: forget the adapter too.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            # Pooling models keep no sampling state.
            self.pooling_params.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.spec_decode_unsupported_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-slot state of the requests at indices i1 and i2.

        Both indices must currently hold requests. For autoregressive models
        the swap is also recorded in ``batch_update_builder.moved`` so
        logitsprocs can follow it.
        """

        def swap_rows(buf: np.ndarray | torch.Tensor) -> None:
            # Indexing with a list (advanced indexing) makes the right-hand
            # side a copy, so rows i1 and i2 are exchanged correctly for both
            # numpy arrays and torch tensors. A tuple swap such as
            # `buf[i1], buf[i2] = buf[i2], buf[i1]` is NOT safe for
            # multi-dimensional buffers. Both right-hand items are views, so
            # after the first assignment row i1 already holds row i2's data,
            # and row i2 ends up unchanged.
            buf[[i1, i2]] = buf[[i2, i1]]

        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        swap_rows(self.num_tokens)
        swap_rows(self.num_tokens_no_spec)
        swap_rows(self.num_prompt_tokens)
        swap_rows(self.num_computed_tokens_cpu)

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_rows(self.is_token_ids)

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        swap_rows(self.request_lora_mapping)
        # Per-slot flag written by add_request() (pooling requests); it must
        # move together with its request like every other per-slot field.
        swap_rows(self.logits_processing_needs_token_ids)

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        swap_rows(self.temperature_cpu)
        swap_rows(self.top_p_cpu)
        swap_rows(self.top_k_cpu)
        swap_rows(self.frequency_penalties_cpu)
        swap_rows(self.presence_penalties_cpu)
        swap_rows(self.repetition_penalties_cpu)
        swap_rows(self.num_accepted_tokens_cpu)

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # Rows of this 2-D tensor must not be tuple-swapped (see
            # swap_rows()); that would copy row i2 into both rows.
            swap_rows(self.allowed_token_ids_mask_cpu_tensor)
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled. Empty indices come from ``self.batch_update_builder.removed``
        (populated by ``remove_request``) and are consumed as they are
        filled. Afterwards ``_req_ids``, ``req_output_token_ids`` and
        ``spec_token_ids`` are trimmed to ``num_reqs`` entries.

        For autoregressive models each move is recorded in
        ``batch_update_builder.moved`` so logitsprocs can follow it.

        Returns nothing; all state is updated in place.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            # No slot is occupied, so no per-slot prompt embeddings are needed.
            self.req_prompt_embeds.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            spec_token_ids = self.spec_token_ids[last_req_index]
            self.spec_token_ids[empty_index] = spec_token_ids
            self.spec_token_ids[last_req_index] = None

            # Only the first num_tokens columns of the row are meaningful.
            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
            else:
                # The moved request has no prompt embeddings. Drop any left
                # over from the removed request that previously held this slot.
                self.req_prompt_embeds.pop(empty_index, None)
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]
            # Per-slot flag written by add_request(); move it with its request.
            self.logits_processing_needs_token_ids[empty_index] = (
                self.logits_processing_needs_token_ids[last_req_index]
            )

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
                # Clear the vacated row so a request placed there later does
                # not inherit this mask (False = don't fill with -inf).
                self.allowed_token_ids_mask_cpu_tensor[last_req_index] = False

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]
    def refresh_metadata(self):
        """Push pending batch changes into ``self.sampling_metadata``.

        Pooling models only consult the builder's "changed" flag. Other models
        also hand the accumulated batch update to every logits processor. In
        both cases the sampling metadata is rebuilt only when something changed.
        """
        if not self.is_pooling_model:
            # Fetch the accumulated update (resetting the builder's tracking)
            # and let each logits processor apply it to its own state.
            batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
            for processor in self.logitsprocs.all:
                processor.update_state(batch_update)
            needs_rebuild = bool(batch_update)
        else:
            # Pooling models skip logits processors; reset() reports whether
            # the batch changed since the last refresh.
            needs_rebuild = self.batch_update_builder.reset()

        if needs_rebuild:
            self.sampling_metadata = self._make_sampling_metadata()
748
749
750

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Snapshot the batch's sampling state into a ``SamplingMetadata``.

        Only the first ``num_reqs`` rows are used. CPU-side parameter tensors
        are copied to their counterparts with ``copy_slice`` only when some
        request in the batch uses them. Prompt/output token ids are attached
        only when penalties, bad words or logits processors need them.
        """
        n = self.num_reqs
        penalties_active = not self.no_penalties

        # Temperature is left as None when every request is greedy.
        if self.all_greedy:
            temperature = None
        else:
            temperature = copy_slice(self.temperature_cpu_tensor, self.temperature, n)

        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, n)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, n)

        if penalties_active:
            # Syncing these tensors is expensive, so they are copied only when
            # at least one request in the batch applies a penalty.
            for cpu_src, dst in (
                (self.frequency_penalties_cpu_tensor, self.frequency_penalties),
                (self.presence_penalties_cpu_tensor, self.presence_penalties),
                (self.repetition_penalties_cpu_tensor, self.repetition_penalties),
            ):
                copy_slice(cpu_src, dst, n)

        # Prompt token ids are consumed by penalties and by requests flagged in
        # ``logits_processing_needs_token_ids``; build the tensor only then.
        if penalties_active or self.logits_processing_needs_token_ids[:n].any():
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        # Output token ids are exposed only when something will read them.
        output_token_ids: list[list[int]] = []
        if (
            penalties_active
            or self.bad_words_token_ids
            or self.logitsprocs_need_output_token_ids
        ):
            output_token_ids = cast(list[list[int]], self.req_output_token_ids)

        if self.no_allowed_token_ids:
            allowed_token_ids_mask = None
        else:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                n,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:n]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:n],
            top_k=None if self.no_top_k else self.top_k[:n],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:n],
            presence_penalties=self.presence_penalties[:n],
            repetition_penalties=self.repetition_penalties[:n],
            output_token_ids=output_token_ids,
            spec_token_ids=cast(list[list[int]], self.spec_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

833
834
835
836
837
838
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return each request's ``PoolingParams``, ordered like ``req_ids``.

        Raises ``AssertionError`` if ``req_ids`` and ``pooling_params`` differ
        in size, and ``KeyError`` if a request id has no pooling params.
        """
        ordered_ids = self.req_ids
        params_by_id = self.pooling_params
        assert len(ordered_ids) == len(params_by_id)
        return list(map(params_by_id.__getitem__, ordered_ids))

    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the ``PoolingMetadata`` for the requests currently batched.

        ``prompt_lens`` comes from ``torch.from_numpy`` and therefore shares
        memory with ``self.num_prompt_tokens``; ``prompt_token_ids`` is reused
        from the current ``self.sampling_metadata``.
        """
        params = self.get_pooling_params()
        # NOTE(review): because of the shared memory, later in-place writes to
        # self.num_prompt_tokens (e.g. in condense) are visible through
        # prompt_lens — confirm callers consume it before the batch mutates.
        prompt_lens = torch.from_numpy(self.num_prompt_tokens[: self.num_reqs])
        return PoolingMetadata(
            prompt_lens=prompt_lens,
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=params,
        )

846
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
847
848
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
849
850
851
852
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
853
854
            pin_memory=self.pin_memory,
        )
855
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
856
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
857
858
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
859
        for i in range(num_reqs):
860
861
            prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
862

863
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """Expand per-request LoRA ids into the per-token mappings.

        Args:
            num_scheduled_tokens: Scheduled-token count for each request in
                the batch, in batch order.
            num_sampled_tokens: Sampled-token count for each request in the
                batch, in batch order.

        Returns:
            A 3-tuple of:
            1. prompt_lora_mapping: length ``np.sum(num_sampled_tokens)``;
               entry ``i`` is the LoRA id to use for the ``i``-th sampled token.
            2. token_lora_mapping: length ``np.sum(num_scheduled_tokens)``;
               entry ``i`` is the LoRA id to use for the ``i``-th token.
            3. lora_requests: every ``LoRARequest`` currently held in
               ``self.lora_id_to_lora_request``.
        """
        lora_id_per_req = self.request_lora_mapping[: self.num_reqs]

        def _per_token(counts: np.ndarray) -> tuple[int, ...]:
            # Assumes a NumPy array: repeat() yields request i's id counts[i]
            # times, preserving batch order.
            return tuple(lora_id_per_req.repeat(counts))

        sampled_mapping = _per_token(num_sampled_tokens)
        scheduled_mapping = _per_token(num_scheduled_tokens)
        return (
            sampled_mapping,
            scheduled_mapping,
            set(self.lora_id_to_lora_request.values()),
        )

888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
        async_copy_ready_event: torch.cuda.Event,
    ) -> None:
        """Remember the async-copied sampled ids for a later output-id repair.

        Used with async scheduling. The tensor and its copy-ready event are
        kept only when the current sampling metadata exposes output token ids.
        Otherwise both references are cleared, so no later repair takes place.
        """
        keep_refs = bool(self.sampling_metadata.output_token_ids)
        self.sampled_token_ids_cpu = sampled_token_ids_cpu if keep_refs else None
        self.async_copy_ready_event = async_copy_ready_event if keep_refs else None

    def update_async_output_token_ids(self) -> None:
        """Patch trailing ``-1`` placeholders with the real sampled token ids.

        Used with async scheduling, right before logits processors read the
        output token ids. For each request that has an entry in
        ``prev_req_id_to_index`` and whose output list ends in ``-1``, that
        last entry is replaced with the id sampled for it in the previous step.
        The copy-ready event is synchronized at most once, and only if some
        placeholder actually needs patching.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Not async scheduling, or output token ids are not needed.
            return

        assert self.prev_req_id_to_index is not None
        prev_index_of = self.prev_req_id_to_index.get
        resolved_ids = None  # materialized lazily from sampled_token_ids_cpu
        for batch_index, req_id in enumerate(self.req_ids):
            prev_index = prev_index_of(req_id)
            if prev_index is None:
                # No previous-step index for this request: nothing to patch.
                continue
            token_ids = output_token_ids[batch_index]
            # Only a trailing -1 is a placeholder. Anything else means tokens
            # were discarded (e.g. after a kv-load failure), so leave it alone.
            if token_ids and token_ids[-1] == -1:
                if resolved_ids is None:
                    assert self.async_copy_ready_event is not None
                    self.async_copy_ready_event.synchronize()
                    resolved_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
                token_ids[-1] = resolved_ids[prev_index]

934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
    @property
    def num_reqs(self) -> int:
        # Number of requests currently in the batch, i.e. the number of
        # entries in the req_id -> batch index mapping.
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty (no randomly sampled request)."""
        return not self.random_reqs

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty (no greedily sampled request)."""
        return not self.greedy_reqs

    @property
    def no_top_p(self) -> bool:
        """True when no request in the batch is tracked in ``top_p_reqs``."""
        return not self.top_p_reqs

    @property
    def no_top_k(self) -> bool:
        """True when no request in the batch is tracked in ``top_k_reqs``."""
        return not self.top_k_reqs

954
955
    @property
    def no_penalties(self) -> bool:
        """True when no request uses a presence, frequency or repetition penalty.

        Checks, in order, ``presence_penalties_reqs``,
        ``frequency_penalties_reqs`` and ``repetition_penalties_reqs``, and
        stops at the first non-empty one.
        """
        return not (
            self.presence_penalties_reqs
            or self.frequency_penalties_reqs
            or self.repetition_penalties_reqs
        )
961

962
    @property
963
    def max_num_logprobs(self) -> int | None:
964
        return max(self.num_logprobs.values()) if self.num_logprobs else None
965
966
967

    @property
    def no_prompt_logprob(self) -> bool:
        """True when ``num_prompt_logprobs`` is empty (falsy)."""
        return not self.num_prompt_logprobs
969
970
971
972

    @property
    def no_allowed_token_ids(self) -> bool:
        """True when no request is tracked in ``has_allowed_token_ids``."""
        return not self.has_allowed_token_ids