# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch

5
from dataclasses import dataclass, field
6
from typing import Optional, cast
7
8
9
10

import numpy as np
import torch

11
from vllm import envs
12
from vllm.lora.request import LoRARequest
13
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
14
from vllm.pooling_params import PoolingParams
15
from vllm.sampling_params import SamplingParams, SamplingType
16
from vllm.utils import swap_dict_values
17
from vllm.v1.outputs import LogprobsTensors
18
from vllm.v1.pool.metadata import PoolingMetadata
19
20
21
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
                                             MoveDirectionality,
                                             init_builtin_logitsprocs)
22
from vllm.v1.sample.metadata import SamplingMetadata
23
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
24
from vllm.v1.utils import copy_slice
25
from vllm.v1.worker.block_table import MultiGroupBlockTable
26
27
28
29
30
31


@dataclass
class CachedRequestState:
    """State of a single request, as consumed by `InputBatch.add_request`.

    `num_prompt_tokens` is derived from `prompt_token_ids` in
    `__post_init__`; it is not a constructor argument.
    """

    req_id: str
    prompt_token_ids: list[int]
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list[PlaceholderRange]
    # A request carries sampling params or pooling params; `add_request`
    # asserts that pooling params are set when sampling params are falsy.
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    num_kv_tokens: int
    output_token_ids: list[int]
    # `None` means the request carries no speculative tokens.
    spec_token_ids: Optional[list[int]] = None

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    # Lazily populated when `VLLM_V1_FAST_TOKEN_ID_COPY` is enabled.
    _prompt_token_ids_np: Optional[np.ndarray] = field(default=None,
                                                       init=False,
                                                       repr=False,
                                                       compare=False)

    def __post_init__(self):
        # Cache the prompt length once; `num_tokens` and `get_token_id`
        # both rely on it.
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        """Number of prompt tokens plus output tokens generated so far."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    def get_token_id(self, idx: int) -> int:
        """Return the token at position `idx` of prompt + output tokens.

        Indices below `num_prompt_tokens` address the prompt; the rest
        address `output_token_ids`.
        """
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        else:
            return self.output_token_ids[idx - self.num_prompt_tokens]
67
68
69
70
71


class InputBatch:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        is_spec_decode: bool = False,
        logits_processing_needs_token_ids: bool = False,
    ):
        """Allocate the per-request buffers and bookkeeping containers.

        Args:
            max_num_reqs: Number of rows allocated in every per-request
                buffer.
            max_model_len: Number of columns of the CPU token-id buffer;
                also forwarded to the block table.
            max_num_batched_tokens: Stored and forwarded to the block
                table.
            device: Device on which the sampling-parameter tensors are
                allocated.
            pin_memory: Whether the CPU staging tensors for the
                sampling parameters and token counters are pinned.
            vocab_size: Vocabulary size (stored; read by `add_request`).
            block_sizes: The block size of each kv cache group,
                forwarded to the block table.
            is_spec_decode: Stored; read by `add_request` when deciding
                whether to track requests unsupported by spec decoding.
            logits_processing_needs_token_ids: Stored on the instance.
        """
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        self.logits_processing_needs_token_ids = (
            logits_processing_needs_token_ids)

        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()
        self.num_kv_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
        )

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
        )
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Track whether sampling metadata is currently expanded to
        # per-token shape (spec decode reject path).
        self._sampling_metadata_is_expanded = False

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # Define logits processors.
        # TODO(andy): logits processor list should be extensible via engine
        # constructor argument; for now the list is fixed.
        self.logitsprocs = init_builtin_logitsprocs(
            pin_memory_available=pin_memory,
            max_num_reqs=max_num_reqs + 1,
            device=device)

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

262
    @property
263
    def req_ids(self) -> list[str]:
264
265
        # None elements should only be present transiently
        # while performing state updates to the batch.
266
        return cast(list[str], self._req_ids)
267

268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
    def _get_next_add_index(self) -> int:
        if (req_index := self.batch_update_builder.pop_removed()) is not None:
            # Fill the empty index.
            return req_index
        # Append to end
        return self.num_reqs

    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Track add-request operations"""
        req_index = self._get_next_add_index()
        assert req_index < self.max_num_reqs
        params = (request.sampling_params
                  if request.sampling_params else request.pooling_params)
        self.batch_update_builder.added.append(
            (req_index, params, request.output_token_ids))
        return req_index

285
286
287
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert `request` into the batch and populate its slot.

        The slot comes from `_register_add_request`. Token ids,
        token counters, the block-table row, sampling or pooling
        parameters and the LoRA mapping are then written for that slot.

        Args:
            request: The request state to add.

        Returns:
            The batch slot index assigned to the request.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
            self.token_ids_cpu[
                req_index, :num_prompt_tokens] = request.prompt_token_ids
            start_idx = num_prompt_tokens
            end_idx = start_idx + len(request.output_token_ids)
            self.token_ids_cpu[req_index,
                               start_idx:end_idx] = request.output_token_ids
        else:
            # OPTIMIZATION: Use np.copyto with pre-converted NumPy arrays
            # instead of slice assignment. This avoids the ~550 µs overhead
            # of converting Python list to NumPy array each time.
            cached_prompt_np = getattr(request, "_prompt_token_ids_np", None)
            # Reuse the cached array only if it is an int32 array whose
            # shape matches the current prompt exactly.
            if (isinstance(cached_prompt_np, np.ndarray)
                    and cached_prompt_np.dtype == np.int32
                    and cached_prompt_np.shape == (num_prompt_tokens, )):
                prompt_token_ids_np = cached_prompt_np
            else:
                prompt_token_ids_np = np.asarray(request.prompt_token_ids,
                                                 dtype=np.int32)
                try:
                    request._prompt_token_ids_np = prompt_token_ids_np
                except (AttributeError, TypeError):
                    # Caching is best-effort: the request object may not
                    # accept new attribute values.
                    pass
            np.copyto(
                self.token_ids_cpu[req_index, :num_prompt_tokens],
                prompt_token_ids_np,
                casting='no',
            )
            start_idx = num_prompt_tokens
            output_token_ids_np = np.asarray(request.output_token_ids,
                                             dtype=np.int32)
            end_idx = start_idx + output_token_ids_np.size
            np.copyto(
                self.token_ids_cpu[req_index, start_idx:end_idx],
                output_token_ids_np,
                casting='no',
            )

        # Speculative tokens, if any, are stored right after the output
        # tokens.
        num_spec_tokens = 0
        if request.spec_token_ids is not None:
            num_spec_tokens = len(request.spec_token_ids)
            self.token_ids_cpu[req_index, end_idx:end_idx +
                               num_spec_tokens] = request.spec_token_ids

        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens + num_spec_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if (self.is_spec_decode
                    and is_spec_decode_unsupported(sampling_params)):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Avoid later division by zero.
                self.temperature_cpu[req_index] = -1.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top_k is stored as the full vocabulary size.
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[
                req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[
                req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[
                req_index] = sampling_params.repetition_penalty
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                self.num_logprobs[req_id] = sampling_params.logprobs
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[
                    req_id] = sampling_params.prompt_logprobs

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device)
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu")
                # Mask every token first, then unmask the allowed ones.
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[
                    req_index] = sampling_params.bad_words_token_ids
        else:
            assert request.pooling_params is not None
            self.pooling_params[req_id] = request.pooling_params

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index

449
    def remove_request(self, req_id: str) -> Optional[int]:
450
451
452
453
454
455
456
457
        """This method must always be followed by a call to condense().
        
        Args:
          req_id: request to remove

        Returns:
          Removed request index, or `None` if `req_id` not recognized
        """
458

459
460
461
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
462
        self.batch_update_builder.removed_append(req_index)
463
464
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None
465
466
467
468
469

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
470
        self.spec_decode_unsupported_reqs.discard(req_id)
471
472
473
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
474
475
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
476
        self.num_prompt_logprobs.pop(req_id, None)
477
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
478
479
480
481
482
483
484
485
486
487

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

488
489
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
490
            # False means we don't fill with -inf.
491
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
492
        self.bad_words_token_ids.pop(req_index, None)
493
        self.pooling_params.pop(req_id, None)
494
495
        return req_index

496
    def swap_states(self, i1: int, i2: int) -> None:
        """Exchange all per-request state held in batch slots `i1` and `i2`.

        The swap is recorded in the batch update builder's `moved` list
        as `(i1, i2, MoveDirectionality.SWAP)`. Both slots must hold a
        request (asserted).

        Args:
            i1: First batch slot index.
            i2: Second batch slot index.
        """
        self.batch_update_builder.moved.append(
            (i1, i2, MoveDirectionality.SWAP))
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1] # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        self.num_tokens[i1], self.num_tokens[i2] =\
            self.num_tokens[i2], self.num_tokens[i1]
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
        self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
            self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
            self.temperature_cpu[i2], self.temperature_cpu[i1]
        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
            self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
            self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # Indexing a row of a 2-D tensor returns a view, so a tuple
            # swap would overwrite row i1 before it is read back and
            # leave both rows equal to row i2. Clone one row first.
            mask = self.allowed_token_ids_mask_cpu_tensor
            tmp_mask_row = mask[i1].clone()
            mask[i1] = mask[i2]
            mask[i2] = tmp_mask_row
        self.block_table.swap_row(i1, i2)

553
554
555
556
557
558
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled.

559
        Args:
560
561
562
563
564
          empty_req_indices: empty indices which may be filled.

        Returns:
          swaps: list of (from,to) swap tuples for moved requests
          empty_req_indices: indices not filled by condensation
565
        """
566
567
568
569
        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
570
571
        num_reqs = self.num_reqs
        if num_reqs == 0:
572
            # The batched states are empty.
573
574
            self._req_ids.clear()
            self.req_output_token_ids.clear()
575
576
577
578
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
579
        last_req_index = num_reqs + len(empty_req_indices) - 1
580
581
582
583
584
585
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
586
587
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
588
589
590
            if empty_index >= last_req_index:
                break

591
592
593
594
595
596
            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index,
                 MoveDirectionality.UNIDIRECTIONAL))
597
598
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
599
            assert req_id is not None
600
601
602
603
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
604
605
            self.req_id_to_index[req_id] = empty_index

606
607
608
609
            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
610
611
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
612
613
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
614
615
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
laibao's avatar
laibao committed
616
617
            self.num_kv_tokens_cpu[
                empty_index] = self.num_kv_tokens_cpu[last_req_index]
618
            self.block_table.move_row(last_req_index, empty_index)
619
620
621
622
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
623
624
625
626
627
628
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
629
630
631
632
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

633
634
635
            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

636
            # TODO convert these to LogitsProcessors
637
638
639
640
641
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

642
643
644
645
            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
646

647
648
649
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

650
651
652
653
        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]

654
    def refresh_metadata(self, repeat_counts: Optional[torch.Tensor] = None):
        """Apply batch updates, reset input batch at end of step

        * Apply batch add/remove/permute to logits procs' states
        * If batch state is modified, update sampling metadata

        Args:
            repeat_counts: When given, sampling metadata is rebuilt via
                `_make_sampling_metadata_expanded(repeat_counts)` and
                flagged as expanded; when `None`, the regular
                `_make_sampling_metadata()` is used.
                NOTE(review): presumably per-request repeat counts for the
                spec decode path; confirm against the caller.
        """
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        # Rebuild when the batch changed, when an expanded view is requested,
        # or when the current metadata is still expanded from an earlier
        # call and must be restored to the regular form.
        needs_rebuild = (batch_update
                         or repeat_counts is not None
                         or self._sampling_metadata_is_expanded)
        if needs_rebuild:
            if repeat_counts is None:
                self.sampling_metadata = self._make_sampling_metadata()
            else:
                self.sampling_metadata = self._make_sampling_metadata_expanded(
                    repeat_counts)
            self._sampling_metadata_is_expanded = repeat_counts is not None
        # Expanded metadata is built on demand; do not cache a copy here.
674
675
676

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build the per-request ``SamplingMetadata`` for the current batch.

        For every sampling parameter that is in use, the first ``num_reqs``
        entries of the ``*_cpu_tensor`` are copied into the matching tensor
        (``self.temperature``, ``self.top_p``, ...) with ``copy_slice``.
        The result also carries host-side top-k summaries computed from
        ``self.top_k_cpu``.
        """
        num_reqs = self.num_reqs

        # Temperature is not needed when every request samples greedily.
        temperature = None
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        needs_prompt_token_ids = (not self.no_penalties or
                                  (num_reqs > 0
                                   and self.logits_processing_needs_token_ids))
        prompt_token_ids = (self._make_prompt_token_ids_tensor()
                            if needs_prompt_token_ids else None)

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        # Host-side summaries to avoid device synchronization in sampling
        # fast paths (e.g. reduced top-k/top-p sampling).
        max_top_k = None
        has_any_no_top_k = False
        if not self.no_top_k and num_reqs > 0:
            top_k_host = self.top_k_cpu[:num_reqs]
            max_top_k = int(top_k_host.max())
            # A top-k equal to the vocab size is counted as "no top-k".
            has_any_no_top_k = bool((top_k_host == self.vocab_size).any())

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
            max_top_k=max_top_k,
            has_any_no_top_k=has_any_no_top_k,
        )

747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
    def _make_sampling_metadata_expanded(
        self, repeat_counts: torch.Tensor
    ) -> SamplingMetadata:
        """Build a ``SamplingMetadata`` with one row per repeated request.

        Request ``i`` contributes ``repeat_counts[i]`` consecutive rows to
        every per-request tensor, to ``output_token_ids`` and to the
        row-indexed ``generators`` / ``bad_words_token_ids`` dicts.
        Requests with a repeat count <= 0 contribute no rows.

        Args:
            repeat_counts: Per-request repeat counts; only the entries for
                the first ``num_reqs`` requests are consumed here.
                NOTE(review): presumably a CPU integer tensor, since it is
                used as the ``repeats`` argument for the ``*_cpu_tensor``
                inputs - confirm against the caller.

        Returns:
            A freshly built ``SamplingMetadata``; the persistent
            ``self.temperature`` / ``self.top_p`` / ... tensors are not
            written by this method.
        """
        num_reqs = self.num_reqs
        all_greedy = self.all_greedy
        no_penalties = self.no_penalties

        def _expand_cpu_to_gpu(
            t: Optional[torch.Tensor],
            *,
            dtype: Optional[torch.dtype] = None,
        ) -> Optional[torch.Tensor]:
            # Repeat the rows of the first `num_reqs` entries, then move the
            # result to `self.device` (optionally casting to `dtype`).
            if t is None:
                return None
            expanded = t[:num_reqs].repeat_interleave(repeat_counts, dim=0)
            return expanded.to(device=self.device,
                               dtype=dtype,
                               non_blocking=True)

        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        needs_prompt_token_ids = (not no_penalties or
                                  (num_reqs > 0
                                   and self.logits_processing_needs_token_ids))
        prompt_token_ids = (self._make_prompt_token_ids_tensor(repeat_counts)
                            if needs_prompt_token_ids else None)

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            allowed_token_ids_mask = self.allowed_token_ids_mask_cpu_tensor

        # Tensors that get expanded to per-row shape below.
        top_p_src = None if self.no_top_p else self.top_p_cpu_tensor
        top_k_src = None if self.no_top_k else self.top_k_cpu_tensor

        # Expand the per-request Python-side state. The row-indexed dicts
        # are keyed by the position of the row in the expanded batch.
        repeat_list = repeat_counts.tolist()
        expanded_output_token_ids: list[list[int]] = []
        expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
        expanded_generators: dict[int, torch.Generator] = {}
        row_idx = 0
        for req_idx in range(num_reqs):
            repeat = int(repeat_list[req_idx])
            if repeat <= 0:
                continue
            output_tokens = self.req_output_token_ids[req_idx]
            assert output_tokens is not None
            bad_words = self.bad_words_token_ids.get(req_idx)
            generator = self.generators.get(req_idx)
            for _ in range(repeat):
                # Every repeated row shares the request's own objects.
                expanded_output_token_ids.append(output_tokens)
                if bad_words is not None:
                    expanded_bad_words_token_ids[row_idx] = bad_words
                if generator is not None:
                    expanded_generators[row_idx] = generator
                row_idx += 1

        # Host-side top-k summaries. These are read from `self.top_k_cpu`
        # and deliberately kept under their own name: reusing the name of
        # the tensor that is expanded in the return statement would hand
        # this array to `_expand_cpu_to_gpu` instead of the torch tensor.
        max_top_k = None
        has_any_no_top_k = False
        if not self.no_top_k and num_reqs > 0:
            top_k_host = self.top_k_cpu[:num_reqs]
            max_top_k = int(top_k_host.max())
            has_any_no_top_k = bool((top_k_host == self.vocab_size).any())

        return SamplingMetadata(
            temperature=_expand_cpu_to_gpu(
                None if all_greedy else self.temperature_cpu_tensor),
            all_greedy=all_greedy,
            all_random=self.all_random,
            top_p=_expand_cpu_to_gpu(top_p_src),
            top_k=_expand_cpu_to_gpu(top_k_src, dtype=torch.int32),
            generators=expanded_generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=(
                None if no_penalties else _expand_cpu_to_gpu(
                    self.frequency_penalties_cpu_tensor)),
            presence_penalties=(
                None if no_penalties else _expand_cpu_to_gpu(
                    self.presence_penalties_cpu_tensor)),
            repetition_penalties=(
                None if no_penalties else _expand_cpu_to_gpu(
                    self.repetition_penalties_cpu_tensor)),
            output_token_ids=expanded_output_token_ids,
            no_penalties=no_penalties,
            allowed_token_ids_mask=_expand_cpu_to_gpu(
                allowed_token_ids_mask, dtype=torch.bool),
            bad_words_token_ids=expanded_bad_words_token_ids,
            logitsprocs=self.logitsprocs,
            max_top_k=max_top_k,
            has_any_no_top_k=has_any_no_top_k,
        )

855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
    @property
    def pooling_metadata(self) -> PoolingMetadata:
        """Assemble the ``PoolingMetadata`` for the current batch.

        The pooling params are listed in ``self.req_ids`` order; the list
        is empty when no pooling params are registered.
        """
        ordered_params = []
        if len(self.pooling_params) != 0:
            # Note, for now this assumes that all request in the batch
            # are either sampling or pooling requests
            assert len(self.req_ids) == len(self.pooling_params)
            for req_id in self.req_ids:
                ordered_params.append(self.pooling_params[req_id])

        prompt_lens_np = self.num_prompt_tokens[:self.num_reqs]
        return PoolingMetadata(
            prompt_lens=torch.from_numpy(prompt_lens_np).to(self.device),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=ordered_params,
        )

874
875
876
    def _make_prompt_token_ids_tensor(
        self, repeat_counts_cpu: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
877
878
879
880
881
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
882
883
            pin_memory=self.pin_memory,
        )
884
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
885
886
        prompt_token_ids[:] = self.token_ids_cpu[:self.
                                                 num_reqs, :max_prompt_len]
887
888
889
890
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
891
892
893
        if repeat_counts_cpu is not None:
            prompt_token_ids_cpu_tensor = prompt_token_ids_cpu_tensor \
                .repeat_interleave(repeat_counts_cpu, dim=0)
894
895
896
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

897
898
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """
        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        # Request i's LoRA id is repeated num_scheduled_tokens[i] times.
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
        # Every request registered in `lora_id_to_lora_request` is returned.
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
    @property
    def num_reqs(self) -> int:
        """Number of entries in ``self.req_id_to_index``."""
        tracked = self.req_id_to_index
        return len(tracked)

    @property
    def all_greedy(self) -> bool:
        """True when ``self.random_reqs`` is empty."""
        num_random = len(self.random_reqs)
        return num_random == 0

    @property
    def all_random(self) -> bool:
        """True when ``self.greedy_reqs`` is empty."""
        num_greedy = len(self.greedy_reqs)
        return num_greedy == 0

    @property
    def no_top_p(self) -> bool:
        """True when ``self.top_p_reqs`` is empty."""
        num_top_p = len(self.top_p_reqs)
        return num_top_p == 0

    @property
    def no_top_k(self) -> bool:
        """True when ``self.top_k_reqs`` is empty."""
        num_top_k = len(self.top_k_reqs)
        return num_top_k == 0

940
941
942
943
944
945
    @property
    def no_penalties(self) -> bool:
        """True when none of the three penalty request sets has entries."""
        penalty_req_sets = (
            self.presence_penalties_reqs,
            self.frequency_penalties_reqs,
            self.repetition_penalties_reqs,
        )
        return not any(len(reqs) for reqs in penalty_req_sets)

946
    @property
    def max_num_logprobs(self) -> Optional[int]:
        """Largest value in ``self.num_logprobs``, or None when it is empty."""
        if not self.num_logprobs:
            return None
        return max(self.num_logprobs.values())
949
950
951

    @property
    def no_prompt_logprob(self) -> bool:
        """True when ``self.num_prompt_logprobs`` is empty (falsy)."""
        return not self.num_prompt_logprobs
953
954
955
956

    @property
    def no_allowed_token_ids(self) -> bool:
        """True when ``self.has_allowed_token_ids`` is empty."""
        num_restricted = len(self.has_allowed_token_ids)
        return num_restricted == 0