# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Data structures defining a GPU input batch

from dataclasses import dataclass
6
from typing import cast
7
8
9
10

import numpy as np
import torch

11
from vllm.lora.request import LoRARequest
12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams, SamplingType
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.utils.collection_utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata
19
20
21
22
23
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
24
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
26
from vllm.v1.utils import copy_slice
27
from vllm.v1.worker.block_table import MultiGroupBlockTable
28
29
30
31
32


@dataclass
class CachedRequestState:
    """State of a single request that the worker keeps between steps.

    ``num_prompt_tokens`` is not a dataclass field; it is derived in
    ``__post_init__`` from ``prompt_token_ids`` and/or ``prompt_embeds``.
    """

    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    def __post_init__(self):
        # The prompt may be given as token ids or as embeddings; the shared
        # helper computes its length from whichever inputs are present.
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )

    @property
    def num_tokens(self) -> int:
        """Total tokens so far: prompt tokens plus generated output tokens."""
        generated = len(self.output_token_ids)
        return self.num_prompt_tokens + generated

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position ``idx`` of prompt + output.

        Positions ``[0, num_prompt_tokens)`` address the prompt, later
        positions address ``output_token_ids``. A position past the last
        output token yields ``-1``.

        Raises:
            ValueError: if ``idx`` falls inside a prompt that was supplied
                only as embeddings (no token ids are known for it).
        """
        prompt_len = self.num_prompt_tokens
        if idx >= prompt_len:
            output_pos = idx - prompt_len
            outputs = self.output_token_ids
            return outputs[output_pos] if output_pos < len(outputs) else -1

        prompt_ids = self.prompt_token_ids
        if prompt_ids is None:
            raise ValueError(
                f"Tried to access token index {idx}, but that token was "
                "provided via prompt_embeds, and its ID is unknown."
            )
        return prompt_ids[idx]


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-slot buffers for the input batch.

        Every per-request buffer has one row (slot) per possible request, i.e.
        ``max_num_reqs`` rows.

        Args:
            max_num_reqs: number of request slots in every buffer.
            max_model_len: per-request capacity of the token-id buffers.
            max_num_batched_tokens: forwarded to the block table.
            device: device holding the sampling-parameter tensors.
            pin_memory: whether the CPU staging tensors are pinned.
            vocab_size: vocabulary size (used for top-k clamping, ``logprobs
                == -1`` and the allowed-token-ids mask width).
            block_sizes: the block_size of each KV cache group.
            kernel_block_sizes: forwarded to the block table.
            logitsprocs: logits processors; an empty ``LogitsProcessors`` is
                created when this is None.
            logitsprocs_need_output_token_ids: stored unchanged.
            is_spec_decode: whether speculative decoding is enabled.
            is_pooling_model: whether the model is a pooling model.
            num_speculative_tokens: forwarded to the block table.
            cp_kv_cache_interleave_size: forwarded to the block table.
        """
        # Static configuration.
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Slot <-> request id bookkeeping.
        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        token_grid = (max_num_reqs, max_model_len)
        self.token_ids_cpu_tensor = torch.zeros(
            token_grid, device="cpu", dtype=torch.int32, pin_memory=False
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # True where token_ids_cpu holds a known token id; False for prompt
        # positions that were supplied only via prompt embeddings.
        self.is_token_ids_tensor = torch.zeros(
            token_grid, device="cpu", dtype=torch.bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()

        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}

        # Per-slot token counters (three independent int32 arrays).
        self.num_tokens, self.num_tokens_no_spec, self.num_prompt_tokens = (
            np.zeros(max_num_reqs, dtype=np.int32) for _ in range(3)
        )
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        def make_sampling_buffers(dtype: torch.dtype):
            # One value per slot: a tensor on `device`, a CPU staging tensor,
            # and a NumPy view sharing memory with the CPU tensor.
            on_device = torch.empty((max_num_reqs,), dtype=dtype, device=device)
            on_cpu = torch.empty(
                (max_num_reqs,), dtype=dtype, device="cpu", pin_memory=pin_memory
            )
            return on_device, on_cpu, on_cpu.numpy()

        # Sampling-related.
        (
            self.temperature,
            self.temperature_cpu_tensor,
            self.temperature_cpu,
        ) = make_sampling_buffers(torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p, self.top_p_cpu_tensor, self.top_p_cpu = make_sampling_buffers(
            torch.float32
        )
        self.top_p_reqs: set[str] = set()

        self.top_k, self.top_k_cpu_tensor, self.top_k_cpu = make_sampling_buffers(
            torch.int32
        )
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        (
            self.frequency_penalties,
            self.frequency_penalties_cpu_tensor,
            self.frequency_penalties_cpu,
        ) = make_sampling_buffers(torch.float)
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        (
            self.presence_penalties,
            self.presence_penalties_cpu_tensor,
            self.presence_penalties_cpu,
        ) = make_sampling_buffers(torch.float)
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        (
            self.repetition_penalties,
            self.repetition_penalties_cpu_tensor,
            self.repetition_penalties_cpu,
        ) = make_sampling_buffers(torch.float)
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # LoRA bookkeeping: slot -> lora id (0 means no LoRA) plus reverse maps.
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        # Both tensors are allocated lazily by add_request.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None

        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None

    @property
    def req_ids(self) -> list[str]:
        """Request ids indexed by batch slot.

        The list is presented as ``list[str]``. ``None`` placeholders (left by
        ``remove_request``) should only exist transiently while the batch is
        being updated, i.e. until the following ``condense()``.
        """
        slot_ids = cast(list[str], self._req_ids)
        return slot_ids

    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Pick the batch slot for ``request`` and record the addition.

        A slot previously handed to ``batch_update_builder`` as removed is
        reused when one is available; otherwise the request goes right after
        the current ``num_reqs`` live requests. The batch is always marked as
        changed. The detailed ``(index, sampling_params, prompt_token_ids,
        output_token_ids)`` record consumed by logits processors is appended
        only for requests carrying sampling params, so pooling requests skip
        it.

        Returns:
            The slot index assigned to the request.
        """
        builder = self.batch_update_builder
        slot = builder.pop_removed()
        if slot is None:
            # No hole to fill: append after the last live request.
            slot = self.num_reqs

        assert slot < self.max_num_reqs

        builder.batch_changed = True
        params = request.sampling_params
        if params:
            builder.added.append(
                (slot, params, request.prompt_token_ids, request.output_token_ids)
            )

        return slot

    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the persistent batch.

        A slot is chosen via ``_register_add_request`` (reusing a removed slot
        when possible). The request's token ids, block-table row, and either
        its sampling or its pooling state are then written into that slot.
        Per-slot state that is only written conditionally (prompt embeddings
        and the allowed-token-ids mask row) is cleared when the request does
        not provide it, so nothing leaks from a previous occupant of the slot.

        Args:
            request: cached state of the request to add.

        Returns:
            The batch slot index assigned to the request.

        Raises:
            NotImplementedError: if the request has neither sampling params
                nor pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # The slot may be reused from a removed request that had prompt
            # embeddings; drop them so they are not attributed to this one.
            self.req_prompt_embeds.pop(req_index, None)
        self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens (prompt token ids or prompt_embeds, plus outputs).
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if self.is_spec_decode and is_spec_decode_unsupported(sampling_params):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range / disabled top-k is stored as the full vocab.
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False
            elif self.allowed_token_ids_mask_cpu_tensor is not None:
                # Clear the row so a reused slot does not inherit a previous
                # request's restrictions. False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

    def remove_request(self, req_id: str) -> int | None:
        """Remove ``req_id`` from the batch, leaving a hole at its slot.

        This method must always be followed by a call to condense().

        The slot is reported to ``batch_update_builder`` as removed, and the
        per-request state held in dicts/sets and per-slot side buffers is
        released so that it cannot leak into a later occupant of the slot.

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()

        # Release per-slot state that add_request only writes conditionally
        # (prompt embeddings, and the pooling-only token-id requirement flag),
        # so it is not held onto or inherited by the slot's next occupant.
        self.req_prompt_embeds.pop(req_index, None)
        self.logits_processing_needs_token_ids[req_index] = False

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this adapter: forget the adapter too.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.spec_decode_unsupported_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-slot state of batch slots ``i1`` and ``i2``.

        Both slots must hold live requests. For non-pooling models the swap is
        also recorded in ``batch_update_builder.moved`` (as ``SWAP``) so that
        logits processors can follow the reordering.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1]  # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        # Scalar elements of 1-D arrays are copied on read, so tuple swaps of
        # single elements are safe.
        self.num_tokens[i1], self.num_tokens[i2] = (
            self.num_tokens[i2],
            self.num_tokens[i1],
        )
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: tuple-swapping whole rows, e.g.
        #     a[i1, ...], a[i2, ...] = a[i2, ...], a[i1, ...]
        # is unsafe: basic indexing yields views, so the first assignment
        # overwrites row i1 before the second one reads it. Indexing with a
        # list (advanced indexing) makes the right-hand side a copy, so the
        # row swaps below are safe.
        # TODO(lucas): optimize this by only copying valid indices
        self.token_ids_cpu[[i1, i2], ...] = self.token_ids_cpu[[i2, i1], ...]
        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
        # Keep the pooling token-id requirement flag with its request.
        self.logits_processing_needs_token_ids[[i1, i2]] = (
            self.logits_processing_needs_token_ids[[i2, i1]]
        )

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        mask = self.allowed_token_ids_mask_cpu_tensor
        if mask is not None:
            # `mask[i1]` / `mask[i2]` are views; tuple-swapping them would
            # leave both rows equal to the old row i2. Advanced indexing on
            # the right-hand side produces a copy, making the swap correct.
            mask[[i1, i2]] = mask[[i2, i1]]

    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        The slots recorded as removed in ``batch_update_builder`` are treated
        as holes. The request at the highest occupied slot is repeatedly moved
        into the lowest hole until no hole lies below an occupied slot. Any
        consecutive empty indices at the very end of the list are not filled;
        the per-slot lists are then trimmed to ``num_reqs``.

        For non-pooling models each move is also recorded in
        ``batch_update_builder.moved`` (as ``UNIDIRECTIONAL``) so logits
        processors can follow it.

        Nothing is returned; all state is updated in place.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request index.
            # From here on empty_index < last_req_index (see the break above).
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # Hand the spec-token list over and leave an empty list behind.
            (
                self.spec_token_ids[last_req_index],
                self.spec_token_ids[empty_index],
            ) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            # Move prompt embeddings with the request. If it has none, drop
            # anything a previous occupant may have left in the target slot
            # (same handling as swap_states).
            moved_embeds = self.req_prompt_embeds.pop(last_req_index, None)
            if moved_embeds is not None:
                self.req_prompt_embeds[empty_index] = moved_embeds
            else:
                self.req_prompt_embeds.pop(empty_index, None)
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]
            # Pooling requests set this flag in add_request; keep it with the
            # request instead of leaving it at the vacated slot.
            self.logits_processing_needs_token_ids[empty_index] = (
                self.logits_processing_needs_token_ids[last_req_index]
            )

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )
                # Clear the vacated row; otherwise a request appended at this
                # slot later would inherit these restrictions.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[last_req_index].fill_(False)

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]

    def refresh_metadata(self):
        """Flush pending batch changes into the cached sampling metadata.

        Pooling models only clear the batch-update builder; the value it
        returns is treated as "the batch changed". All other models first
        forward the accumulated batch update to every logits processor.
        In both cases ``self.sampling_metadata`` is rebuilt only when the
        batch actually changed.
        """
        if not self.is_pooling_model:
            # Hand the add/remove/move record to each logits processor and
            # clear the builder for the next step.
            update = self.batch_update_builder.get_and_reset(self.num_reqs)
            for processor in self.logitsprocs.all:
                processor.update_state(update)
            changed = update
        else:
            # Logits processors are not used for pooling; just reset the
            # tracking and note whether anything changed.
            changed = self.batch_update_builder.reset()

        if changed:
            self.sampling_metadata = self._make_sampling_metadata()
747
748
749

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Snapshot the batch's sampling state into a SamplingMetadata.

        Per-request parameters are staged in the ``*_cpu_tensor`` buffers.
        Each group is copied into its counterpart tensor via ``copy_slice``
        (presumably the first ``num_reqs`` rows -- see copy_slice), but only
        when at least one request in the batch uses that feature. Unused
        features are reported as ``None``; ``output_token_ids`` is reported
        as an empty list.
        """
        n = self.num_reqs

        # Greedy-only batches never look at temperature.
        temperature = None
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, n
            )

        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, n)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, n)

        penalties_active = not self.no_penalties
        if penalties_active:
            # Transferring the penalty tensors is costly, so skip it unless
            # some request actually applies a penalty.
            penalty_pairs = (
                (self.frequency_penalties_cpu_tensor, self.frequency_penalties),
                (self.presence_penalties_cpu_tensor, self.presence_penalties),
                (self.repetition_penalties_cpu_tensor, self.repetition_penalties),
            )
            for staged, target in penalty_pairs:
                copy_slice(staged, target, n)

        # Prompt token ids are needed for penalties and for logits
        # processors flagged in logits_processing_needs_token_ids
        # (get_pooling_metadata also reuses them). Build the padded tensor
        # only when one of those needs it.
        prompt_token_ids = None
        if penalties_active or self.logits_processing_needs_token_ids[:n].any():
            prompt_token_ids = self._make_prompt_token_ids_tensor()

        # Attach output token ids only if penalties, bad-words filtering or
        # a logits processor will read them.
        wants_output_ids = (
            penalties_active
            or self.bad_words_token_ids
            or self.logitsprocs_need_output_token_ids
        )
        if wants_output_ids:
            output_token_ids = cast(list[list[int]], self.req_output_token_ids)
        else:
            output_token_ids = []

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                n,
            )
            allowed_token_ids_mask = self.allowed_token_ids_mask[:n]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:n],
            top_k=None if self.no_top_k else self.top_k[:n],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:n],
            presence_penalties=self.presence_penalties[:n],
            repetition_penalties=self.repetition_penalties[:n],
            output_token_ids=output_token_ids,
            spec_token_ids=cast(list[list[int]], self.spec_token_ids),
            no_penalties=not penalties_active,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

832
833
834
835
836
837
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the PoolingParams of every request, ordered like req_ids.

        Asserts that the number of PoolingParams entries equals the number
        of requests in the batch.
        """
        ordered_ids = self.req_ids
        params_by_id = self.pooling_params
        assert len(ordered_ids) == len(params_by_id)
        return [params_by_id[rid] for rid in ordered_ids]

    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the PoolingMetadata for the requests in the batch.

        Prompt lengths are taken (without copying) from the CPU-side
        ``num_prompt_tokens`` array. Prompt token ids are reused from the
        current sampling metadata, so they are ``None`` unless that
        metadata built them.
        """
        params = self.get_pooling_params()
        prompt_lens = torch.from_numpy(self.num_prompt_tokens[: self.num_reqs])
        return PoolingMetadata(
            prompt_lens=prompt_lens,
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=params,
        )

845
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Build a padded tensor of the batch's prompt token ids.

        Returns an int64 tensor of shape ``(num_reqs, max_prompt_len)``,
        moved to ``self.device`` with a non-blocking copy. Rows shorter than
        the longest prompt are right-padded with ``self.vocab_size``; that
        value is never a real token id. An empty batch yields a ``(0, 0)``
        tensor instead of raising.
        """
        num_reqs = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:num_reqs]
        # ndarray.max() raises on an empty array, so guard the empty batch.
        max_prompt_len = int(prompt_lens.max()) if num_reqs > 0 else 0
        prompt_token_ids_cpu_tensor = torch.empty(
            (num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        # Fill through a NumPy view that shares the tensor's storage.
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Pad every column at or beyond each request's prompt length in a
        # single broadcast step. vocab_size is the pad value because no
        # token has that id.
        pad_mask = np.arange(max_prompt_len) >= prompt_lens[:, None]
        prompt_token_ids[pad_mask] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
861

862
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Build the per-token LoRA id mappings used to activate adapters.

        Args:
            num_scheduled_tokens: Number of scheduled tokens for each request
                in the batch, in batch order.
            num_sampled_tokens: Number of sampled tokens for each request in
                the batch, in batch order.

        Returns:
            1. prompt_lora_mapping: tuple of length np.sum(num_sampled_tokens);
               entry i is the LoRA id to use for the i-th sampled token.
            2. token_lora_mapping: tuple of length np.sum(num_scheduled_tokens);
               entry i is the LoRA id to use for the i-th scheduled token.
            3. lora_requests: set of every LoRARequest currently stored in
               self.lora_id_to_lora_request.
        """
        per_request_ids = self.request_lora_mapping[: self.num_reqs]

        def _expand(counts: np.ndarray) -> tuple[int, ...]:
            # Repeat each request's LoRA id once per token it contributes.
            return tuple(per_request_ids.repeat(counts))

        return (
            _expand(num_sampled_tokens),
            _expand(num_scheduled_tokens),
            set(self.lora_id_to_lora_request.values()),
        )

887
888
889
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
890
        async_copy_ready_event: torch.Event,
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        Fill in placeholder output token ids once async sampling results land.

        Under async scheduling, a trailing -1 in a request's output token ids
        marks a token that the previous step sampled but that has not been
        read back yet. Each such placeholder (for requests present in
        prev_req_id_to_index) is replaced with the real id from
        sampled_token_ids_cpu. The copy-ready event is synchronized at most
        once, and only if some placeholder actually needs filling. Intended
        to run right before logits processors consume output_token_ids.
        """
        outputs = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not outputs:
            # Either not async scheduling, or nothing reads output ids.
            return

        prev_positions = self.prev_req_id_to_index
        assert prev_positions is not None

        sampled = None  # materialized lazily, after the event sync
        for row, req_id in enumerate(self.req_ids):
            prev_row = prev_positions.get(req_id)
            if prev_row is None:
                # Request was not in the previous step's batch.
                continue
            tokens = outputs[row]
            # No trailing placeholder means tokens were discarded (e.g. after
            # a kv-load failure), so there is nothing to fill in.
            if not tokens or tokens[-1] != -1:
                continue
            if sampled is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled = self.sampled_token_ids_cpu.squeeze(-1).tolist()
            tokens[-1] = sampled[prev_row]

933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
    @property
    def num_reqs(self) -> int:
        """Number of requests currently tracked in req_id_to_index."""
        active = self.req_id_to_index
        return len(active)

    @property
    def all_greedy(self) -> bool:
        """True when the random_reqs collection is empty."""
        return not len(self.random_reqs)

    @property
    def all_random(self) -> bool:
        """True when the greedy_reqs collection is empty."""
        return not len(self.greedy_reqs)

    @property
    def no_top_p(self) -> bool:
        """True when no request is registered in top_p_reqs."""
        return not len(self.top_p_reqs)

    @property
    def no_top_k(self) -> bool:
        """True when no request is registered in top_k_reqs."""
        return not len(self.top_k_reqs)

953
954
    @property
    def no_penalties(self) -> bool:
        """True when no request uses a presence, frequency or repetition penalty."""
        penalty_reqs = (
            self.presence_penalties_reqs,
            self.frequency_penalties_reqs,
            self.repetition_penalties_reqs,
        )
        return all(len(reqs) == 0 for reqs in penalty_reqs)
960

961
    @property
962
    def max_num_logprobs(self) -> int | None:
963
        return max(self.num_logprobs.values()) if self.num_logprobs else None
964

965
966
967
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when the has_allowed_token_ids collection is empty."""
        return not len(self.has_allowed_token_ids)