"vscode:/vscode.git/clone" did not exist on "0dd5dee9b9bc88453f5f3eacfde751e6b9ba4871"
gpu_input_batch.py 44.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4

5
from dataclasses import dataclass, field
6
from typing import Optional, cast
7
8
9
10

import numpy as np
import torch

11
12
from vllm import envs
from vllm.config import get_current_vllm_config
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams, SamplingType
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
from vllm.utils.collection_utils import swap_dict_values
19
from vllm.v1.outputs import LogprobsTensors
20
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
21
22
23
24
25
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
26
from vllm.v1.sample.metadata import SamplingMetadata
27
from vllm.v1.utils import copy_slice
28
from vllm.v1.worker.block_table import MultiGroupBlockTable
29
30
31
32
33


@dataclass
class CachedRequestState:
    """Per-request state cached on the worker between scheduling steps.

    Field order is part of the constructor interface and must not change.
    """

    req_id: str
    # Either `prompt_token_ids` or `prompt_embeds` describes the prompt; the
    # prompt length is derived from whichever is present in __post_init__.
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Lazily-built int32 numpy copy of `prompt_token_ids`. It is filled in by
    # InputBatch.add_request on its fast-copy path, so it is excluded from
    # __init__, repr and comparisons.
    _prompt_token_ids_np: np.ndarray | None = field(
        default=None,
        init=False,
        repr=False,
        compare=False,
    )

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        """Derive the prompt length and initialise pooling state."""
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )

        # A pooling request always starts from a fresh PoolingStates object;
        # any value passed for `pooling_states` is replaced.
        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Total number of prompt tokens plus generated tokens."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        """Return the token id at position `idx` of prompt + output.

        Args:
            idx: Position in the concatenated prompt/output sequence.

        Returns:
            The token id, or -1 when `idx` lies beyond the tokens known so far.

        Raises:
            ValueError: If `idx` falls inside the prompt but the prompt was
                supplied via `prompt_embeds`, so no token id exists.
        """
        if idx < self.num_prompt_tokens:
            if self.prompt_token_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown."
                )
            return self.prompt_token_ids[idx]

        output_idx = idx - self.num_prompt_tokens
        if output_idx < len(self.output_token_ids):
            return self.output_token_ids[output_idx]
        return -1
87
88
89
90


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-request buffers of the batch.

        Args:
            max_num_reqs: Number of request slots. When `is_spec_decode` and
                `envs.VLLM_REJECT_SAMPLE_OPT` are both set it is multiplied by
                `1 + num_speculative_tokens` before sizing most buffers.
            max_model_len: Number of token columns in the per-request token
                buffers.
            max_num_batched_tokens: Forwarded to the block table.
            device: Device that holds the sampling-parameter tensors.
            pin_memory: Whether the CPU staging tensors are pinned.
            vocab_size: Vocabulary size; also the "disabled" value for top-k.
            block_sizes: The block_size of each kv cache group.
            kernel_block_sizes: Forwarded to the block table.
            max_num_blocks_per_req: Forwarded to the block table.
            logitsprocs: Logits processors to use; an empty
                `LogitsProcessors()` is created when none are given.
            logitsprocs_need_output_token_ids: Stored as-is on the instance.
            is_spec_decode: Whether speculative decoding is enabled.
            is_pooling_model: Whether the batch serves a pooling model.
            cp_kv_cache_interleave_size: Forwarded to the block table.
        """
        # Keep the caller's request cap; the working cap may be scaled up.
        ori_max_num_reqs = max_num_reqs
        if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
            vllm_config = get_current_vllm_config()
            num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens
            max_num_reqs = max_num_reqs * (1 + num_spec_tokens)

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        # NOTE(review): this buffer has `ori_max_num_reqs` rows while the other
        # per-request buffers below use the (possibly scaled) `max_num_reqs`;
        # confirm slot indices never reach the unscaled cap.
        self.token_ids_cpu_tensor = torch.zeros(
            (ori_max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related.
        # Each parameter has a device tensor, a CPU tensor, a numpy view of
        # the CPU tensor, and a set of request ids that use the feature.
        self.temperature = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device=device
        )
        self.temperature_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
        self.top_p_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
        self.top_k_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.presence_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None
284

285
    @property
286
    def req_ids(self) -> list[str]:
287
288
        # None elements should only be present transiently
        # while performing state updates to the batch.
289
        return cast(list[str], self._req_ids)
290

291
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Pick the slot for a new request and record the addition.

        The detailed "added" record is only produced for requests that carry
        sampling params (it exists to support logits processors); it is not
        applicable to pooling requests.

        Args:
            request: The request about to be added to the batch.

        Returns:
            The slot index the request should occupy.
        """
        # Fill the next empty index if there is one.
        new_req_index = self.batch_update_builder.pop_removed()
        if new_req_index is None:
            # Append to end otherwise.
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
        self.batch_update_builder.batch_changed = True

        if request.sampling_params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            added_entry = (
                new_req_index,
                request.sampling_params,
                request.prompt_token_ids,
                request.output_token_ids,
            )
            self.batch_update_builder.added.append(added_entry)

        return new_req_index
316

317
318
319
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert `request` into the batch and populate its slot.

        Copies the request's prompt/output tokens into the CPU token buffer,
        registers its block ids, and records its sampling (or pooling)
        parameters and LoRA assignment.

        Args:
            request: The request to add. It must carry either sampling params
                or pooling params.

        Returns:
            The slot index assigned to the request.

        Raises:
            NotImplementedError: If the request has neither sampling params
                nor pooling params.
        """
        req_index = self._register_add_request(request)
        req_id = request.req_id

        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Read the flag once; it selects the numpy-based copy path below.
        fast_copy = envs.VLLM_V1_FAST_TOKEN_ID_COPY

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        if request.prompt_token_ids is not None:
            if not fast_copy:
                self.token_ids_cpu[req_index, :num_prompt_tokens] = (
                    request.prompt_token_ids
                )
            else:
                # Reuse the int32 array cached on the request when it still
                # matches the prompt; otherwise (re)build and cache it.
                prompt_token_ids_np = request._prompt_token_ids_np
                rebuild_prompt_cache = True
                if prompt_token_ids_np is not None:
                    rebuild_prompt_cache = (
                        prompt_token_ids_np.dtype != np.int32
                        or prompt_token_ids_np.size != num_prompt_tokens
                    )
                if rebuild_prompt_cache:
                    prompt_token_ids_np = np.asarray(
                        request.prompt_token_ids, dtype=np.int32
                    )
                    request._prompt_token_ids_np = prompt_token_ids_np
                np.copyto(
                    self.token_ids_cpu[req_index, :num_prompt_tokens],
                    prompt_token_ids_np,
                    casting="no",
                )
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds

        if not fast_copy:
            end_idx = start_idx + len(request.output_token_ids)
            self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        else:
            output_token_ids_np = np.asarray(request.output_token_ids, dtype=np.int32)
            end_idx = start_idx + output_token_ids_np.size
            np.copyto(
                self.token_ids_cpu[req_index, start_idx:end_idx],
                output_token_ids_np,
                casting="no",
            )
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)

            # A top_k outside (0, vocab_size) is stored as vocab_size and the
            # request is not tracked in top_k_reqs.
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k

            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)

            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)

            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # logprobs == -1 requests logprobs over the whole vocabulary.
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                # Mask every token first, then unmask the allowed ones.
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        """Replace the draft (speculative) tokens stored for `request`.

        The request's previous draft tokens are always discarded, and
        `request.prev_num_draft_len` is set to the number of newly scheduled
        draft tokens (zero when there are none).

        Args:
            request: Request whose draft tokens are being refreshed; it must
                already be present in the batch.
            scheduled_spec_tokens: Maps request id to the draft tokens
                scheduled for this step. A request may be absent.
        """
        slot = self.req_id_to_index[request.req_id]
        slot_drafts = self.spec_token_ids[slot]

        # With structured output the scheduler may drop draft tokens that do
        # not conform to the schema, so there can be no scheduled drafts for
        # this request even though speculative decoding is enabled. Always
        # start from an empty list.
        slot_drafts.clear()

        drafts = scheduled_spec_tokens.get(request.req_id, ())
        draft_count = len(drafts)
        request.prev_num_draft_len = draft_count

        if drafts:
            # For async scheduling, the values written into token_ids_cpu
            # here are placeholders and will be overwritten in
            # _prepare_input_ids.
            first = self.num_tokens_no_spec[slot]
            last = first + draft_count
            self.token_ids_cpu[slot, first:last] = drafts
            slot_drafts.extend(drafts)

508
    def remove_request(self, req_id: str) -> int | None:
        """This method must always be followed by a call to condense().

        Clears the per-request bookkeeping for `req_id` and marks its slot as
        removed so that it can be refilled or condensed away.

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()
        # Drop any prompt embeddings kept for this slot. Otherwise the tensor
        # stays referenced and would appear to belong to whichever request
        # later reuses the slot without embeddings of its own.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this adapter: forget the adapter too.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            # Pooling models keep no sampling state, so stop here.
            self.pooling_params.pop(req_id, None)
            self.pooling_states.pop(req_id, None)
            return req_index

        for req_id_set in (
            self.greedy_reqs,
            self.random_reqs,
            self.top_p_reqs,
            self.top_k_reqs,
            self.frequency_penalties_reqs,
            self.presence_penalties_reqs,
            self.repetition_penalties_reqs,
            self.has_allowed_token_ids,
        ):
            req_id_set.discard(req_id)

        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)

        return req_index

562
563
564
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state stored at slots `i1` and `i2`.

        Only the CPU-side arrays and Python containers are swapped here. For
        pooling models the method returns once the state shared by all model
        types has been swapped; sampling-related state is skipped.

        Args:
            i1: Index of the first slot; must hold a live request.
            i2: Index of the second slot; must hold a live request.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        # Check before mutating anything so a failure leaves state untouched.
        assert old_id_i1 is not None and old_id_i2 is not None

        self._req_ids[i1], self._req_ids[i2] = old_id_i2, old_id_i1
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )

        # 1-D numpy arrays: indexing yields scalar copies, so a plain tuple
        # swap is safe for these.
        for counts in (
            self.num_tokens_no_spec,
            self.num_prompt_tokens,
            self.num_computed_tokens_cpu,
        ):
            counts[i1], counts[i2] = counts[i2], counts[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # because row indexing returns views, not copies. Instead, we need to
        # temporarily copy the data for one of the indices.
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        # Advanced indexing on the right-hand side produces a copy, so this
        # form is safe for 2-D buffers.
        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        # 1-D numpy views of the sampling parameters (scalar swap is safe).
        for values in (
            self.temperature_cpu,
            self.top_p_cpu,
            self.top_k_cpu,
            self.frequency_penalties_cpu,
            self.presence_penalties_cpu,
            self.repetition_penalties_cpu,
            self.num_accepted_tokens_cpu,
        ):
            values[i1], values[i2] = values[i2], values[i1]

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        mask_cpu = self.allowed_token_ids_mask_cpu_tensor
        if mask_cpu is not None:
            # Rows of a 2-D tensor are views, so a tuple swap of
            # mask_cpu[i1] / mask_cpu[i2] would leave both rows holding row
            # i2's data. Index with lists so the right-hand side is
            # materialised as a copy before being written back.
            mask_cpu[[i1, i2]] = mask_cpu[[i2, i1]]
664

665
666
667
668
669
670
671
672
673
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

        The method works in place and returns nothing. For non-pooling
        models every relocation is appended to
        ``self.batch_update_builder.moved`` as a
        ``(from_index, to_index, MoveDirectionality.UNIDIRECTIONAL)`` tuple;
        for pooling models the per-request sampling state is not moved and
        no move is recorded.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                # Every remaining hole sits above the last live request, so
                # the batch is already dense.
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # Must be computed before the spec-token lists are swapped below,
            # because it reads the source slot's spec tokens.
            num_tokens = self.num_tokens_no_spec[last_req_index] + len(
                self.spec_token_ids[last_req_index]
            )

            # Swap the list objects (rather than copy) and then empty the one
            # that ends up in the vacated slot.
            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()

            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]
794

795
    def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> None:
        """Apply any batch updates to sampling metadata.

        Args:
            repeat_counts: Optional tensor forwarded unchanged to
                ``_make_sampling_metadata``.

        Pooling models only reset the batch-update tracker and rebuild
        ``self.sampling_metadata`` if it reports a change. Other models also
        hand the batch update to every logits processor before deciding
        whether to rebuild.
        """
        if self.is_pooling_model:
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        # Processors are notified even when batch_update is falsy.
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
812

813
    def _make_sampling_metadata(
        self, repeat_counts: Optional[torch.Tensor] = None
    ) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` from the current batch state.

        Each optional tensor is only produced when at least one request in
        the batch needs it; otherwise the corresponding field is ``None``.

        Args:
            repeat_counts: Optional tensor passed through to every
                ``copy_slice`` call.
                NOTE(review): its exact effect is defined by ``copy_slice``,
                which is not visible here - confirm against that helper.

        Returns:
            The freshly assembled ``SamplingMetadata``.
        """
        num_reqs = self.num_reqs
        no_penalties = self.no_penalties

        temperature = None
        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs, repeat_counts
            )

        top_p = None
        if not self.no_top_p:
            top_p = copy_slice(
                self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts
            )

        top_k = None
        if not self.no_top_k:
            top_k = copy_slice(
                self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts
            )

        frequency_penalties = None
        presence_penalties = None
        repetition_penalties = None
        if not no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            frequency_penalties = copy_slice(
                self.frequency_penalties_cpu_tensor,
                self.frequency_penalties,
                num_reqs,
                repeat_counts,
            )
            presence_penalties = copy_slice(
                self.presence_penalties_cpu_tensor,
                self.presence_penalties,
                num_reqs,
                repeat_counts,
            )
            repetition_penalties = copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
                repeat_counts,
            )

        needs_prompt_token_ids = (
            not no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        allowed_token_ids_mask: torch.Tensor | None = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            allowed_token_ids_mask = copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
                repeat_counts,
            )

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=top_p,
            top_k=top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=frequency_penalties,
            presence_penalties=presence_penalties,
            repetition_penalties=repetition_penalties,
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )

894
895
896
897
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of every request, in ``req_ids`` order."""
        request_ids = self.req_ids
        params_by_request = self.pooling_params
        # Every request in the batch is expected to have exactly one entry.
        assert len(request_ids) == len(params_by_request)
        ordered_params = []
        for request_id in request_ids:
            ordered_params.append(params_by_request[request_id])
        return ordered_params

898
899
900
901
    def get_pooling_states(self) -> list[PoolingStates]:
        """Return the pooling state of every request, in ``req_ids`` order."""
        request_ids = self.req_ids
        states_by_request = self.pooling_states
        # Every request in the batch is expected to have exactly one entry.
        assert len(request_ids) == len(states_by_request)
        ordered_states = []
        for request_id in request_ids:
            ordered_states.append(states_by_request[request_id])
        return ordered_states

902
903
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble the ``PoolingMetadata`` for the current batch.

        Prompt lengths come from the first ``num_reqs`` entries of
        ``num_prompt_tokens``. The prompt token ids are taken from the
        existing ``self.sampling_metadata`` rather than recomputed here.
        """
        pooling_params = self.get_pooling_params()
        pooling_states = self.get_pooling_states()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
            pooling_states=pooling_states,
        )

913
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Return a padded ``(num_reqs, max_prompt_len)`` prompt-token tensor.

        The tensor is staged on the CPU (pinned if ``self.pin_memory``),
        filled from ``self.token_ids_cpu``, padded with ``self.vocab_size``
        past each request's prompt length, and then copied to
        ``self.device`` with ``non_blocking=True``.
        """
        num_reqs = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:num_reqs]
        max_prompt_len = prompt_lens.max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        # Writes through this NumPy view are writes to the tensor itself.
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        # One broadcast comparison marks every column at or beyond each
        # row's prompt length.
        padding_mask = np.arange(max_prompt_len) >= prompt_lens[:, None]
        prompt_token_ids[padding_mask] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
929

930
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens and num_sampled_tokens for each request
        in the batch, return datastructures used to activate the current LoRAs.

        Args:
            num_scheduled_tokens: Per-request repeat counts used to expand
                the request-level LoRA ids into token_lora_mapping.
            num_sampled_tokens: Per-request repeat counts used to expand
                the request-level LoRA ids into prompt_lora_mapping.

        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where, prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        # Only the first num_reqs rows belong to live requests.
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

955
956
957
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
958
        async_copy_ready_event: torch.Event,
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        Placeholder entries (``-1``) at the tail of each request's output
        list are overwritten in place with the ids sampled for that request
        in the previous step. Requests absent from
        ``self.prev_req_id_to_index`` are left untouched.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        # Materialised lazily so the event is only waited on if at least one
        # request actually has placeholders to fill.
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.tolist()
            # Replace placeholder token id(s) with actual sampled id(s).
            new_ids: list[int] = sampled_token_ids[prev_index]
            if not new_ids:
                continue
            # A -1 inside new_ids marks where the valid sampled ids end.
            num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
            # Also account for case where there may be a smaller number of
            # output placeholders (tokens can be discarded after a kv-load failure).
            first_placeholder = req_output_token_ids.index(-1)
            num_placeholders = len(req_output_token_ids) - first_placeholder
            num_to_replace = min(num_sampled_ids, num_placeholders)
            del new_ids[num_to_replace:]
            end_index = first_placeholder + num_to_replace
            req_output_token_ids[first_placeholder:end_index] = new_ids

    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        Async-scheduling helper: overwrite the spec_token_ids held by the
        sampling metadata with the real draft token ids from the prior step.
        Intended to run right before the rejection sampler needs them for
        penalty/bad_words computation.

        Each matching draft list is truncated in place to the length of the
        request's current spec-token list before being copied over.
        """
        if not draft_token_ids or not self.prev_req_id_to_index:
            return

        per_request_spec_ids = self.sampling_metadata.spec_token_ids
        if per_request_spec_ids is None:
            return

        previous_positions = self.prev_req_id_to_index
        for request_id, current_spec in zip(self.req_ids, per_request_spec_ids):
            if not current_spec:
                # This request has no spec tokens to replace.
                continue
            previous_position = previous_positions.get(request_id)
            if previous_position is None:
                # Request was not part of the previous step.
                continue
            drafts = draft_token_ids[previous_position]
            if not drafts:
                continue
            # Keep at most as many drafts as there are existing spec slots.
            del drafts[len(current_spec) :]
            current_spec.clear()
            current_spec.extend(drafts)
1031

1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
    @property
    def num_reqs(self) -> int:
        """Number of requests currently tracked in ``req_id_to_index``."""
        index_by_request = self.req_id_to_index
        return len(index_by_request)

    @property
    def all_greedy(self) -> bool:
        """True when ``random_reqs`` is empty."""
        return not self.random_reqs

    @property
    def all_random(self) -> bool:
        """True when ``greedy_reqs`` is empty."""
        return not self.greedy_reqs

    @property
    def no_top_p(self) -> bool:
        """True when ``top_p_reqs`` is empty."""
        return not self.top_p_reqs

    @property
    def no_top_k(self) -> bool:
        """True when ``top_k_reqs`` is empty."""
        return not self.top_k_reqs

1052
1053
    @property
    def no_penalties(self) -> bool:
        """True when the presence, frequency and repetition penalty request
        collections are all empty."""
        return not (
            self.presence_penalties_reqs
            or self.frequency_penalties_reqs
            or self.repetition_penalties_reqs
        )
1059

1060
    @property
1061
    def max_num_logprobs(self) -> int | None:
1062
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1063

1064
1065
1066
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``has_allowed_token_ids`` is empty."""
        return not self.has_allowed_token_ids