gpu_input_batch.py 44.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# Datastructures defining a GPU input batch
4

5
from dataclasses import dataclass, field
6
from typing import Optional, cast
7
8
9
10

import numpy as np
import torch

11
12
from vllm import envs
from vllm.config import get_current_vllm_config
13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams, SamplingType
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
from vllm.utils.collection_utils import swap_dict_values
19
from vllm.v1.outputs import LogprobsTensors
20
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
21
22
23
24
25
from vllm.v1.sample.logits_processor import (
    BatchUpdateBuilder,
    LogitsProcessors,
    MoveDirectionality,
)
26
from vllm.v1.sample.metadata import SamplingMetadata
27
from vllm.v1.utils import copy_slice
28
from vllm.v1.worker.block_table import MultiGroupBlockTable
29
30
31
32
33


@dataclass
class CachedRequestState:
    """Per-request state that the worker keeps across scheduling steps.

    ``num_prompt_tokens`` is not a constructor argument; it is derived in
    ``__post_init__`` from ``prompt_token_ids`` or ``prompt_embeds``.
    """

    req_id: str

    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    generator: torch.Generator | None

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

    # Cached int32 copy of ``prompt_token_ids``. InputBatch.add_request fills
    # and reuses it when VLLM_V1_FAST_TOKEN_ID_COPY is enabled. Excluded from
    # __init__, repr and equality.
    _prompt_token_ids_np: np.ndarray | None = field(
        default=None, init=False, repr=False, compare=False
    )

    # Used when both async_scheduling and spec_decode are enabled.
    prev_num_draft_len: int = 0

    # for pooling models
    pooling_params: PoolingParams | None = None
    pooling_states: PoolingStates | None = None

    def __post_init__(self):
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        # A pooling request always starts from a fresh PoolingStates; any value
        # passed to the constructor is replaced.
        if self.pooling_params is not None:
            self.pooling_states = PoolingStates()

    @property
    def num_tokens(self) -> int:
        """Number of prompt tokens plus tokens generated so far."""
        generated = len(self.output_token_ids)
        return self.num_prompt_tokens + generated

    def get_token_id(self, idx: int) -> int:
        """Return the token id at absolute position ``idx``.

        Positions below ``num_prompt_tokens`` address the prompt; the following
        positions address ``output_token_ids``. Returns -1 for a position past
        the last known token.

        Raises:
            ValueError: if ``idx`` lies in a prompt that was supplied only as
                embeddings, so no token id exists for it.
        """
        out_idx = idx - self.num_prompt_tokens
        if out_idx < 0:
            # Position is inside the prompt.
            if self.prompt_token_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown."
                )
            return self.prompt_token_ids[idx]
        if out_idx < len(self.output_token_ids):
            return self.output_token_ids[out_idx]
        return -1
87
88
89
90


class InputBatch:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the persistent per-slot buffers of the batch.

        Args:
            max_num_reqs: Maximum number of concurrent requests. When
                ``is_spec_decode`` and ``envs.VLLM_REJECT_SAMPLE_OPT`` are
                both set, it is multiplied by
                ``1 + speculative_config.num_speculative_tokens``. The result
                becomes ``self.max_num_reqs`` and sizes the sampling and
                bookkeeping buffers and the block table. The per-request
                token-id buffers ``token_ids_cpu`` / ``is_token_ids`` keep
                the caller's value.
            max_model_len: Number of token columns per request row.
            max_num_batched_tokens: Forwarded to the block table.
            device: Device for the non-CPU sampling tensors.
            pin_memory: Whether CPU staging tensors are allocated pinned.
            vocab_size: Vocabulary size; also the fallback ``top_k``.
            block_sizes: The block_size of each kv cache group.
            kernel_block_sizes: Forwarded to the block table.
            max_num_blocks_per_req: Forwarded to the block table.
            logitsprocs: Logits processors; an empty collection if ``None``.
            logitsprocs_need_output_token_ids: Stored on the instance.
            is_spec_decode: Whether speculative decoding is enabled.
            is_pooling_model: Whether the batch serves a pooling model.
            cp_kv_cache_interleave_size: Forwarded to the block table.
        """
        # Request count as given by the caller. The per-request token-id
        # buffers below are sized by this value, not by the expanded one.
        ori_max_num_reqs = max_num_reqs
        if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
            vllm_config = get_current_vllm_config()
            num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens
            max_num_reqs = max_num_reqs * (1 + num_spec_tokens)

        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}
        self.invalid_req_indices: list[int] = []

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (ori_max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Row-for-row companion of token_ids_cpu: every index written here is
        # also written in token_ids_cpu, so both must have the same number of
        # rows. Sizing this by the expanded max_num_reqs only wasted memory.
        self.is_token_ids_tensor = torch.zeros(
            (ori_max_num_reqs, max_model_len),
            device="cpu",
            dtype=bool,
            pin_memory=False,
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # Sampling-related.
        self.temperature = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device=device
        )
        self.temperature_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
        self.top_p_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
        )
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device)
        self.top_k_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.presence_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty(
            (max_num_reqs,), dtype=torch.float, device=device
        )
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory
        )
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Store last speculative tokens for sampler.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        # for pooling models
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # These are used to update output_token_ids with real sampled
        # ids from prior step, if required by current sampling params
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None
285

286
    @property
    def req_ids(self) -> list[str]:
        """Request ids indexed by batch slot.

        A slot may hold ``None`` while the batch is being updated (for
        example between ``remove_request()`` and ``condense()``). The cast
        only narrows the static type; the underlying list is returned as-is,
        without filtering or copying.
        """
        ids = self._req_ids
        return cast(list[str], ids)
291

292
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Choose a batch slot for ``request`` and record the addition.

        A slot freed by ``remove_request()`` (tracked in
        ``batch_update_builder``) is reused when one is available; otherwise
        the request is placed right after the current last request.
        Logits-processor metadata is recorded only for requests that carry
        sampling params; pooling requests just mark the batch as changed.

        Returns:
            The batch index assigned to the request.
        """
        builder = self.batch_update_builder
        slot = builder.pop_removed()
        if slot is None:
            # No hole to fill: append after the last live request.
            slot = self.num_reqs
        assert slot < self.max_num_reqs

        builder.batch_changed = True
        params = request.sampling_params
        if params:
            entry = (slot, params, request.prompt_token_ids, request.output_token_ids)
            builder.added.append(entry)
        return slot
317

318
319
320
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the batch and fill its per-slot state.

        Copies the prompt/output token ids into ``token_ids_cpu`` and marks
        them in ``is_token_ids``. Records prompt embeddings, the block table
        row, sampling or pooling parameters, and the LoRA mapping for the
        chosen slot.

        Returns:
            The batch index (slot) assigned to the request.

        Raises:
            NotImplementedError: if the request has neither sampling params
                nor pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
            self.spec_token_ids.append([])
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
            self.spec_token_ids[req_index].clear()

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        if request.prompt_token_ids is not None:
            if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
                self.token_ids_cpu[req_index, :num_prompt_tokens] = (
                    request.prompt_token_ids
                )
            else:
                # Reuse the int32 array cached on the request when it still
                # matches the prompt; otherwise build it once and cache it.
                prompt_token_ids_np = request._prompt_token_ids_np
                if (
                    prompt_token_ids_np is None
                    or prompt_token_ids_np.dtype != np.int32
                    or prompt_token_ids_np.size != num_prompt_tokens
                ):
                    prompt_token_ids_np = np.asarray(
                        request.prompt_token_ids, dtype=np.int32
                    )
                    request._prompt_token_ids_np = prompt_token_ids_np
                # casting="no" turns any dtype mismatch into an error instead
                # of a silent conversion.
                np.copyto(
                    self.token_ids_cpu[req_index, :num_prompt_tokens],
                    prompt_token_ids_np,
                    casting="no",
                )
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        else:
            # The slot may be reused from an earlier request; drop any
            # embeddings it left behind so they are neither kept alive nor
            # attributed to this request.
            self.req_prompt_embeds.pop(req_index, None)

        if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
            end_idx = start_idx + len(request.output_token_ids)
            self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
        else:
            output_token_ids_np = np.asarray(request.output_token_ids, dtype=np.int32)
            end_idx = start_idx + output_token_ids_np.size
            np.copyto(
                self.token_ids_cpu[req_index, start_idx:end_idx],
                output_token_ids_np,
                casting="no",
            )
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top_k means "no top-k"; store the vocab size.
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[req_index] = (
                sampling_params.repetition_penalty
            )
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # logprobs == -1 requests logprobs for the whole vocabulary.
                self.num_logprobs[req_id] = (
                    self.vocab_size
                    if sampling_params.logprobs == -1
                    else sampling_params.logprobs
                )

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device,
                    )
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu",
                    )
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids
                ] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[req_index] = (
                    sampling_params.bad_words_token_ids
                )
        elif pooling_params := request.pooling_params:
            pooling_states = request.pooling_states
            assert pooling_states is not None

            self.pooling_params[req_id] = pooling_params
            self.pooling_states[req_id] = pooling_states
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids
            )
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
    def update_req_spec_token_ids(
        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
    ) -> None:
        """Replace the draft (speculative) tokens stored for ``request``.

        Previously stored drafts are always discarded, and
        ``request.prev_num_draft_len`` is set to the new draft count (0 when
        none were scheduled). New drafts are written into ``token_ids_cpu``
        right after the request's non-speculative tokens and mirrored in
        ``spec_token_ids``.
        """
        req_id = request.req_id
        slot = self.req_id_to_index[req_id]
        stored = self.spec_token_ids[slot]
        # With structured output the scheduler may drop draft tokens that do
        # not conform to the schema, so scheduled_spec_tokens can lack an
        # entry even when speculative decoding is enabled. Always start from
        # an empty draft list.
        stored.clear()
        drafts = scheduled_spec_tokens.get(req_id, ())
        request.prev_num_draft_len = len(drafts)
        if drafts:
            # Under async scheduling these values are placeholders that are
            # overwritten later in _prepare_input_ids.
            first = self.num_tokens_no_spec[slot]
            self.token_ids_cpu[slot, first : first + len(drafts)] = drafts
            stored.extend(drafts)

509
    def remove_request(self, req_id: str) -> int | None:
        """Remove ``req_id`` from the batch, leaving a hole at its slot.

        This method must always be followed by a call to condense().

        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
        self.spec_token_ids[req_index].clear()
        # Release the prompt embeddings held for this slot. Otherwise they
        # stay referenced after the request is gone, and because condense()
        # only moves entries that exist at the source index, a stale tensor
        # could survive at a slot that is later reused.
        self.req_prompt_embeds.pop(req_index, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this adapter: forget the adapter too.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            self.pooling_params.pop(req_id, None)
            self.pooling_states.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
        if self.prev_req_id_to_index is not None:
            self.prev_req_id_to_index.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

563
564
565
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-slot state between batch slots ``i1`` and ``i2``.

        Both slots must hold live requests. For non-pooling models the swap
        is also recorded in ``batch_update_builder.moved`` so logits
        processors can follow the reordering, and the sampling state is
        swapped as well.
        """
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] = (
            self._req_ids[i2],
            self._req_ids[i1],
        )
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
            self.req_output_token_ids[i2],
            self.req_output_token_ids[i1],
        )
        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
            self.spec_token_ids[i2],
            self.spec_token_ids[i1],
        )
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
            self.req_id_to_index[old_id_i2],
            self.req_id_to_index[old_id_i1],
        )
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
            self.num_tokens_no_spec[i2],
            self.num_tokens_no_spec[i1],
        )
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = (
            self.num_prompt_tokens[i2],
            self.num_prompt_tokens[i1],
        )
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = (
            self.num_computed_tokens_cpu[i2],
            self.num_computed_tokens_cpu[i1],
        )

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        # Advanced indexing on the right-hand side copies both rows first.
        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

        # Swap prompt embeddings if they exist
        embeds_i1 = self.req_prompt_embeds.get(i1)
        embeds_i2 = self.req_prompt_embeds.get(i2)
        if embeds_i1 is not None:
            self.req_prompt_embeds[i2] = embeds_i1
        else:
            self.req_prompt_embeds.pop(i2, None)
        if embeds_i2 is not None:
            self.req_prompt_embeds[i1] = embeds_i2
        else:
            self.req_prompt_embeds.pop(i1, None)

        self.block_table.swap_row(i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = (
            self.request_lora_mapping[i2],
            self.request_lora_mapping[i1],
        )

        if self.is_pooling_model:
            # Sampling and logits parameters don't apply to pooling models.
            return

        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs.
        self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP))

        # Indexing a 1-D array yields scalars (copies), so tuple swaps are safe.
        self.temperature_cpu[i1], self.temperature_cpu[i2] = (
            self.temperature_cpu[i2],
            self.temperature_cpu[i1],
        )
        self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = (
            self.frequency_penalties_cpu[i2],
            self.frequency_penalties_cpu[i1],
        )
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = (
            self.presence_penalties_cpu[i2],
            self.presence_penalties_cpu[i1],
        )
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = (
            self.repetition_penalties_cpu[i2],
            self.repetition_penalties_cpu[i1],
        )
        self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = (
            self.num_accepted_tokens_cpu[i2],
            self.num_accepted_tokens_cpu[i1],
        )

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # mask[i] is a view of row i, so a tuple swap would copy row i2
            # into i1 and then copy that already-overwritten row back into
            # i2, leaving both rows equal to the old i2. Advanced indexing on
            # the right-hand side materializes a copy of both rows first.
            mask = self.allowed_token_ids_mask_cpu_tensor
            mask[[i1, i2]] = mask[[i2, i1]]
665

666
667
668
669
670
671
672
673
674
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        The indices recorded as removed in ``self.batch_update_builder`` are
        holes in the persistent batch. Each iteration moves the live request
        with the largest index into the smallest hole (as reported by
        ``peek_removed``), until no hole lies below a live request. Any
        consecutive empty indices at the very end of the batch are not
        filled; instead the per-request Python lists (``_req_ids``,
        ``req_output_token_ids``, ``spec_token_ids``) are trimmed to
        ``self.num_reqs``.

        For non-pooling models every move is also appended to
        ``self.batch_update_builder.moved`` so logits processors can follow
        it. Pooling models skip the sampling-only state.

        Returns:
            None; all state is updated in place.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            self.spec_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        # Highest slot index in use before condensing (live requests + holes).
        last_req_index = num_reqs + len(empty_req_indices) - 1
        # NOTE(review): `empty_req_indices` is the builder's `removed`
        # container itself; pop_removed() is expected to shrink it, which is
        # what terminates this loop.
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            # Token count to copy: non-speculative tokens plus any pending
            # speculative tokens of the moved request.
            num_tokens = self.num_tokens_no_spec[last_req_index] + len(
                self.spec_token_ids[last_req_index]
            )

            # Swap the spec-token lists, then empty the one left behind.
            (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
            self.spec_token_ids[last_req_index].clear()

            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens
            ]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens
            ]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                    last_req_index
                )
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index
            ]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
            self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
                last_req_index
            ]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index
            ]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
            )

            self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[
                last_req_index
            ]
            self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[
                last_req_index
            ]
            self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[
                last_req_index
            ]
            self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[
                last_req_index
            ]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                    self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                )

            bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        del self.spec_token_ids[num_reqs:]
    def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
        """Apply any batch updates to sampling metadata.

        Args:
            repeat_counts: Optional tensor forwarded to
                ``_make_sampling_metadata``. When given, the sampling metadata
                of a non-pooling model is rebuilt even if the batch itself did
                not change.
        """
        if self.is_pooling_model:
            # Pooling models have no logits processors; only rebuild the
            # metadata when the batch composition changed.
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)

        # Metadata built with repeat_counts is (presumably) expanded for that
        # call only. If the previous build used repeat_counts, rebuild now even
        # when the batch is unchanged and no repeat_counts are passed, so the
        # repeated tensors are not reused for this step.
        previously_repeated = getattr(self, "_sampling_metadata_repeated", False)
        if batch_update or repeat_counts is not None or previously_repeated:
            self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
            self._sampling_metadata_repeated = repeat_counts is not None
    def _make_sampling_metadata(
        self, repeat_counts: Optional[torch.Tensor] = None
    ) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` snapshot for the current batch.

        Per-request CPU sampling parameters are copied to their device tensors
        with ``copy_slice`` only when some request needs them; otherwise the
        corresponding metadata field is ``None``.

        Args:
            repeat_counts: Optional tensor forwarded to every ``copy_slice``
                call. NOTE(review): only the ``copy_slice``-produced tensors
                (temperature, top-p/top-k, penalties, allowed-token mask)
                receive it. ``prompt_token_ids``, ``output_token_ids``,
                ``spec_token_ids``, ``generators`` and ``bad_words_token_ids``
                are built per request. Confirm consumers expect this mix.

        Returns:
            The freshly built ``SamplingMetadata``.
        """
        num_reqs = self.num_reqs

        # Optional device tensors; they stay None when no request needs them.
        temperature: torch.Tensor | None = None
        top_p: torch.Tensor | None = None
        top_k: torch.Tensor | None = None
        frequency_penalties: torch.Tensor | None = None
        presence_penalties: torch.Tensor | None = None
        repetition_penalties: torch.Tensor | None = None
        allowed_token_ids_mask: torch.Tensor | None = None

        # Host-side summaries for reduced top-k/top-p sampling.
        # Compute before copy_slice(top_k), which may rewrite top_k_cpu_tensor
        # when repeat_counts is provided.
        max_top_k: int | None = None
        has_any_no_top_k = False
        if not self.no_top_k and num_reqs > 0:
            top_k_cpu = self.top_k_cpu[:num_reqs]
            max_top_k = int(top_k_cpu.max())
            has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())

        if not self.all_greedy:
            temperature = copy_slice(
                self.temperature_cpu_tensor, self.temperature, num_reqs, repeat_counts
            )
        if not self.no_top_p:
            top_p = copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts)
        if not self.no_top_k:
            top_k = copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            frequency_penalties = copy_slice(
                self.frequency_penalties_cpu_tensor,
                self.frequency_penalties,
                num_reqs,
                repeat_counts,
            )
            presence_penalties = copy_slice(
                self.presence_penalties_cpu_tensor,
                self.presence_penalties,
                num_reqs,
                repeat_counts,
            )
            repetition_penalties = copy_slice(
                self.repetition_penalties_cpu_tensor,
                self.repetition_penalties,
                num_reqs,
                repeat_counts,
            )

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any()
        )
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = (
            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
        )

        # Only set output_token_ids if required by the current requests'
        # sampling parameters.
        needs_output_token_ids = (
            not self.no_penalties
            or bool(self.bad_words_token_ids)
            or self.logitsprocs_need_output_token_ids
        )
        output_token_ids = (
            cast(list[list[int]], self.req_output_token_ids)
            if needs_output_token_ids
            else []
        )

        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            allowed_token_ids_mask = copy_slice(
                self.allowed_token_ids_mask_cpu_tensor,
                self.allowed_token_ids_mask,
                num_reqs,
                repeat_counts,
            )

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=top_p,
            top_k=top_k,
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=frequency_penalties,
            presence_penalties=presence_penalties,
            repetition_penalties=repetition_penalties,
            output_token_ids=output_token_ids,
            spec_token_ids=self.spec_token_ids,
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
            max_top_k=max_top_k,
            has_any_no_top_k=has_any_no_top_k,
        )
    def get_pooling_params(self) -> list[PoolingParams]:
        """Return the pooling params of every request, in ``self.req_ids`` order.

        Raises:
            AssertionError: if ``self.req_ids`` and ``self.pooling_params``
                differ in length.
        """
        assert len(self.req_ids) == len(self.pooling_params)
        return [self.pooling_params[req_id] for req_id in self.req_ids]
    def get_pooling_states(self) -> list[PoolingStates]:
        """Return the pooling state of every request, in ``self.req_ids`` order.

        Raises:
            AssertionError: if ``self.req_ids`` and ``self.pooling_states``
                differ in length.
        """
        assert len(self.req_ids) == len(self.pooling_states)
        return [self.pooling_states[req_id] for req_id in self.req_ids]
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble ``PoolingMetadata`` for the requests currently in the batch.

        Prompt lengths are taken from ``self.num_prompt_tokens`` (wrapped with
        ``torch.from_numpy``). Prompt token ids are reused from the current
        ``self.sampling_metadata``.
        """
        pooling_params = self.get_pooling_params()
        pooling_states = self.get_pooling_states()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
            pooling_states=pooling_states,
        )
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Return the prompt token ids of all live requests as a padded tensor.

        Builds a ``(num_reqs, max_prompt_len)`` int64 CPU tensor (allocated
        with ``pin_memory=self.pin_memory``) from the leading columns of
        ``self.token_ids_cpu``. Every position at or beyond a request's own
        prompt length is set to ``self.vocab_size``. The result is copied to
        ``self.device`` with ``non_blocking=True``.

        An empty batch yields a ``(0, 0)`` tensor instead of raising.
        """
        num_reqs = self.num_reqs
        prompt_lens = self.num_prompt_tokens[:num_reqs]
        # ndarray.max() raises on an empty array, so guard the empty batch.
        max_prompt_len = int(prompt_lens.max()) if num_reqs > 0 else 0
        prompt_token_ids_cpu_tensor = torch.empty(
            (num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value. One broadcast mask covers every row at once.
        pad_mask = np.arange(max_prompt_len) >= prompt_lens[:, None]
        prompt_token_ids[pad_mask] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the per-request num_scheduled_tokens and num_sampled_tokens of
        the batch, return datastructures used to activate the current LoRAs.

        Returns:
            1. prompt_lora_mapping: A tuple of size np.sum(num_sampled_tokens)
               where, prompt_lora_mapping[i] is the LoRA id to use for the ith
               sampled token.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of all LoRA requests currently held in
               ``self.lora_id_to_lora_request``.
        """
        # Per-request LoRA ids; `.repeat(counts)` expands each request's id
        # counts[i] times (assumes request_lora_mapping is a numpy array).
        req_lora_mapping = self.request_lora_mapping[: self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping.repeat(num_sampled_tokens))
        token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values()
        )

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
    def set_async_sampled_token_ids(
        self,
        sampled_token_ids_cpu: torch.Tensor,
971
        async_copy_ready_event: torch.Event,
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
    ) -> None:
        """
        In async scheduling case, store ref to sampled_token_ids_cpu
        tensor and corresponding copy-ready event. Used to repair
        output_token_ids prior to sampling, if needed by logits processors.
        """
        if self.sampling_metadata.output_token_ids:
            self.sampled_token_ids_cpu = sampled_token_ids_cpu
            self.async_copy_ready_event = async_copy_ready_event
        else:
            self.sampled_token_ids_cpu = None
            self.async_copy_ready_event = None

    def update_async_output_token_ids(self) -> None:
        """
        In async scheduling case, update output_token_ids in sampling metadata
        from prior steps sampled token ids once they've finished copying to CPU.
        This is called right before they are needed by the logits processors.

        Only requests present in ``self.prev_req_id_to_index`` whose output
        list ends with a ``-1`` placeholder are repaired, in place. The copy
        event is synchronized at most once, and only when some request
        actually needs repairing.
        """
        output_token_ids = self.sampling_metadata.output_token_ids
        if self.sampled_token_ids_cpu is None or not output_token_ids:
            # Output token ids not needed or not async scheduling.
            return

        assert self.prev_req_id_to_index is not None
        # Host copy of the previous step's samples, materialized lazily.
        sampled_token_ids = None
        for index, req_id in enumerate(self.req_ids):
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            req_output_token_ids = output_token_ids[index]
            if not req_output_token_ids or req_output_token_ids[-1] != -1:
                # Final output id is not a placeholder, some tokens must have
                # been discarded after a kv-load failure.
                continue
            if sampled_token_ids is None:
                # Wait for the copy only once it is actually needed.
                assert self.async_copy_ready_event is not None
                self.async_copy_ready_event.synchronize()
                sampled_token_ids = self.sampled_token_ids_cpu.tolist()
            # Replace placeholder token id(s) with actual sampled id(s).
            new_ids: list[int] = sampled_token_ids[prev_index]
            if not new_ids:
                continue
            # A -1 inside new_ids marks the end of the valid sampled ids.
            num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
            # Also account for case where there may be a smaller number of
            # output placeholders (tokens can be discarded after a kv-load failure).
            first_placeholder = req_output_token_ids.index(-1)
            num_placeholders = len(req_output_token_ids) - first_placeholder
            num_to_replace = min(num_sampled_ids, num_placeholders)
            del new_ids[num_to_replace:]
            end_index = first_placeholder + num_to_replace
            req_output_token_ids[first_placeholder:end_index] = new_ids
    def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
        """
        In async scheduling case, update spec_token_ids in sampling metadata with
        real draft token ids from prior step. This is called right before they are
        needed by the rejection sampler for penalty/bad_words computation.

        Each request with a non-empty spec-token list that also existed in the
        previous step gets its list overwritten in place by that step's drafts.
        The drafts are truncated to the current list's length first, and the
        truncation mutates the entry inside ``draft_token_ids``.
        """
        if not draft_token_ids or not self.prev_req_id_to_index:
            return

        spec_token_ids = self.sampling_metadata.spec_token_ids
        if spec_token_ids is None:
            return

        for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
            if not spec_ids:
                continue
            prev_index = self.prev_req_id_to_index.get(req_id)
            if prev_index is None:
                continue
            draft_ids = draft_token_ids[prev_index]
            if not draft_ids:
                continue
            del draft_ids[len(spec_ids) :]
            spec_ids.clear()
            spec_ids.extend(draft_ids)
    @property
    def num_reqs(self) -> int:
        """Number of entries in ``self.req_id_to_index``."""
        index_map = self.req_id_to_index
        return len(index_map)

    @property
    def all_greedy(self) -> bool:
        """True when ``self.random_reqs`` is empty."""
        num_random = len(self.random_reqs)
        return num_random == 0

    @property
    def all_random(self) -> bool:
        """True when ``self.greedy_reqs`` is empty."""
        num_greedy = len(self.greedy_reqs)
        return num_greedy == 0

    @property
    def no_top_p(self) -> bool:
        """True when ``self.top_p_reqs`` is empty."""
        num_top_p = len(self.top_p_reqs)
        return num_top_p == 0

    @property
    def no_top_k(self) -> bool:
        """True when ``self.top_k_reqs`` is empty."""
        return len(self.top_k_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when no request uses a presence, frequency or repetition penalty.

        Checks that ``presence_penalties_reqs``, ``frequency_penalties_reqs``
        and ``repetition_penalties_reqs`` are all empty.
        """
        return (
            len(self.presence_penalties_reqs) == 0
            and len(self.frequency_penalties_reqs) == 0
            and len(self.repetition_penalties_reqs) == 0
        )

    @property
1074
    def max_num_logprobs(self) -> int | None:
1075
        return max(self.num_logprobs.values()) if self.num_logprobs else None
1076

1077
1078
1079
    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``self.has_allowed_token_ids`` is empty."""
        num_restricted = len(self.has_allowed_token_ids)
        return num_restricted == 0