gpu_model_runner.py 149 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import gc
6
import itertools
7
import time
8
9
from collections import defaultdict
from collections.abc import Iterator
10
from contextlib import contextmanager
11
from copy import deepcopy
12
from typing import TYPE_CHECKING, Any, Optional, Union, cast
13
14
15
16
17

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
18
from tqdm import tqdm
19

20
import vllm.envs as envs
21
from vllm.attention import Attention, AttentionType
22
from vllm.attention.backends.abstract import AttentionBackend
23
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
24
from vllm.compilation.counter import compilation_counter
25
26
27
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
28
                         get_layers_from_vllm_config, update_config)
29
from vllm.distributed.eplb.eplb_state import EplbState
30
31
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
32
from vllm.distributed.parallel_state import (
33
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
34
    prepare_communication_buffer_for_model)
35
36
from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                  set_forward_context)
37
from vllm.logger import init_logger
38
39
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
40
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
41
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
42
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
43
                                                   supports_eagle3,
44
45
46
                                                   supports_transcription)
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
47
from vllm.multimodal import MULTIMODAL_REGISTRY
48
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
49
                                    PlaceholderRange)
50
from vllm.multimodal.utils import group_mm_kwargs_by_modality
51
from vllm.pooling_params import PoolingParams
52
from vllm.sampling_params import SamplingType
53
from vllm.sequence import IntermediateTensors, PoolerOutput
54
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
55
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
56
57
58
                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                        get_dtype_size, is_pin_memory_available, round_up,
                        supports_dynamo)
59
from vllm.v1.attention.backends.utils import (
60
    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
61
    make_kv_sharing_fast_prefill_attention_metadata,
62
    reorder_batch_to_split_decodes_and_prefills)
63
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
64
65
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        ChunkedLocalAttentionSpec,
66
                                        EncoderOnlyAttentionSpec,
67
                                        FullAttentionSpec, KVCacheConfig,
68
69
                                        KVCacheGroupSpec, KVCacheSpec,
                                        MambaSpec, SlidingWindowSpec)
70
71
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                             LogprobsTensors, ModelRunnerOutput)
72
from vllm.v1.pool.metadata import PoolingMetadata
73
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
74
from vllm.v1.sample.metadata import SamplingMetadata
75
from vllm.v1.sample.rejection_sampler import RejectionSampler
76
from vllm.v1.sample.sampler import Sampler
77
from vllm.v1.spec_decode.eagle import EagleProposer
78
from vllm.v1.spec_decode.medusa import MedusaProposer
79
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
80
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
81
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
82
from vllm.v1.worker.kv_connector_model_runner_mixin import (
83
    KVConnectorModelRunnerMixin, KVConnectorOutput)
84
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
85

86
87
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
                    gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
88
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
89

90
if TYPE_CHECKING:
91
    import xgrammar as xgr
92
    import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
93

94
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
95
    from vllm.v1.core.sched.output import SchedulerOutput
96
97
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
98
99
100
    xgr_torch_compile = LazyLoader(
        "xgr_torch_compile", globals(),
        "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")
101
102
103
104

# Module-level logger, named after this module.
logger = init_logger(__name__)


105
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
106
107
108

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up config handles, persistent CPU/GPU buffers and helpers.

        The model itself and the KV caches are NOT created here; they are
        set up later (see the "Lazy initializations" notes below).

        Args:
            vllm_config: Engine-wide configuration. Its sub-configs are
                cached as attributes for convenient access.
            device: Device on which the persistent buffers are allocated.

        Raises:
            ValueError: If speculative decoding is configured with an
                unknown method on the last pipeline-parallel rank.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        # `cpu_offload_gb` is converted from GiB to bytes here.
        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the model's dtype.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_pooling_model = model_config.pooler_config is not None
        self.is_multimodal_raw_input_supported = (
            model_config.is_multimodal_raw_input_supported)
        self.max_model_len = model_config.max_model_len
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config)

        # Sampler
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

        # State of the expert parallelism load balancer.
        # Will be lazily initialized when the model is loaded.
        self.eplb_state: Optional[EplbState] = None

        # Lazy initializations
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        # indexes: [kv_cache_group_id][attn_group]
        self.attn_groups: list[list[AttentionGroup]] = []
        # self.kv_cache_config: KVCacheConfig

        # mm_hash ->  encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}

        self.use_aux_hidden_state_outputs = False
        # Set up speculative decoding.
        # NOTE(Jiayi): currently we put the entire draft model on
        # the last PP rank. This is not ideal if there are many
        # layers in the draft model.
        if self.speculative_config and get_pp_group().is_last_rank:
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
            elif self.speculative_config.use_eagle():
                self.drafter = EagleProposer(self.vllm_config, self.device,
                                             self)  # type: ignore
                if self.speculative_config.method == "eagle3":
                    self.use_aux_hidden_state_outputs = True
            elif self.speculative_config.method == "medusa":
                self.drafter = MedusaProposer(
                    vllm_config=self.vllm_config,
                    device=self.device)  # type: ignore
            else:
                raise ValueError("Unknown speculative decoding method: "
                                 f"{self.speculative_config.method}")
            self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config, self.device, self.pin_memory,
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors),
            is_pooling_model=self.is_pooling_model,
        )

        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        # NOTE(review): the attribute is only set when CUDA graphs are in use.
        if self.compilation_config.cudagraph_capture_sizes and \
                self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
            self.cudagraph_batch_sizes = list(
                reversed(self.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self._init_device_properties()

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device=self.device)
        self.seq_lens = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device=self.device)

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)
            # NumPy view sharing memory with the CPU tensor above.
            self.mrope_positions_np = self.mrope_positions_cpu.numpy()

        # Only relevant for models using ALiBi (e.g, MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int64)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}
        self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()

        self.kv_sharing_fast_prefill_logits_indices = None
        if self.cache_config.kv_sharing_fast_prefill:
            self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
                self.max_num_tokens, dtype=torch.int32, device=self.device)

        # Query length of a pure-decode request: 1, plus the speculative
        # tokens when speculative decoding is enabled.
        self.uniform_decode_query_len = 1 if not self.speculative_config else \
            1 + self.speculative_config.num_speculative_tokens

        # Cudagraph dispatcher for runtime cudagraph dispatching.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        self.mm_budget = (MultiModalBudget(
            self.model_config,
            self.scheduler_config,
            self.mm_registry,
        ) if self.supports_mm_inputs else None)

        self.reorder_batch_threshold: Optional[int] = None

        # Attention layers that are only in the KVCacheConfig of the runner
        # (e.g., KV sharing, encoder-only attention), but not in the
        # KVCacheConfig of the scheduler.
        self.runner_only_attn_layers: set[str] = set()

        # Cached outputs.
        self._draft_token_ids: Optional[Union[list[list[int]],
                                              torch.Tensor]] = None

355
356
357
358
    def _init_model_kwargs(self, num_tokens: int):
        model_kwargs = dict[str, Any]()
        num_reqs = self.input_batch.num_reqs

359
        num_pooling_reqs = len(self.input_batch.pooling_params)
360
361
362
363

        if num_pooling_reqs == 0:
            return model_kwargs

364
        # This does nontrivial work.
365
366
        pooling_params = self.input_batch.pooling_metadata.pooling_params

367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
        assert num_pooling_reqs == num_reqs

        token_type_id_requests = dict[int, Any]()
        for i, param in enumerate(pooling_params):
            if param.extra_kwargs is not None and \
            (token_types := param.extra_kwargs.get(
                "compressed_token_type_ids")) is not None:
                token_type_id_requests[i] = token_types

        if len(token_type_id_requests) == 0:
            return model_kwargs

        seq_lens = self.seq_lens[:num_reqs]
        token_type_ids = []

        for i in range(num_reqs):
            pos = token_type_id_requests.get(i, seq_lens[i])
            ids = (torch.arange(seq_lens[i]) >= pos).int()
            token_type_ids.append(ids)

        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
            device=self.device)
        return model_kwargs

391
    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
392
393
        """
        Update the order of requests in the batch based on the attention
394
        backend's needs. For example, some attention backends (namely MLA) may
395
396
397
398
399
400
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.
        """
401
402
403
404
405
406
407
408
        # Attention free models have zero kv_cache_goups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping its internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

409
410
411
412
413
        if self.reorder_batch_threshold is not None:
            reorder_batch_to_split_decodes_and_prefills(
                self.input_batch,
                scheduler_output,
                decode_threshold=self.reorder_batch_threshold)
414

415
416
417
418
419
420
421
422
423
424
425
    # Note: used for model runner override.
    def _init_device_properties(self) -> None:
        """Query and cache the CUDA properties of ``self.device``.

        Sets ``self.device_properties`` to the result of
        ``torch.cuda.get_device_properties`` and ``self.num_sms`` to its
        ``multi_processor_count``.
        """
        props = torch.cuda.get_device_properties(self.device)
        self.device_properties = props
        self.num_sms = props.multi_processor_count

    # Note: used for model runner override.
    def _sync_device(self) -> None:
        """Block the host until queued CUDA work has finished.

        Kept as a separate method so other model runners can substitute
        their own device synchronization.
        """
        torch.cuda.synchronize()

426
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Steps, in order:
          1. Drop finished requests (cached state and persistent batch).
          2. Free encoder outputs the scheduler no longer needs.
          3. Remove requests not scheduled this step from the persistent
             batch, keeping their cached state.
          4. Create cached state for new requests.
          5. Update running/resumed requests (tokens, block IDs, spec tokens).
          6. Re-add new/resumed requests, condense, optionally reorder, and
             refresh batch metadata.

        Args:
            scheduler_output: The scheduler output for the current step.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)

        # Free the cached encoder outputs.
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        # The set difference materializes a new set, so removing requests
        # below does not mutate the collection being iterated.
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        reqs_to_add: list[CachedRequestState] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            # Seeded sampling gets its own generator so results are
            # reproducible per request.
            if sampling_params and \
                sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            if pooling_params:
                task = pooling_params.task
                assert task is not None, "You did not set `task` in the API"

                model = cast(VllmModelForPooling, self.get_model())
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_kwargs=new_req_data.mm_kwargs,
                mm_positions=new_req_data.mm_positions,
                mm_hashes=new_req_data.mm_hashes,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                self._init_mrope_positions(req_state)

            reqs_to_add.append(req_state)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_data.resumed_from_preemption[i]

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                                  req_state.num_tokens)
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(
                        new_token_ids[-num_new_tokens:])

            # Update the block IDs.
            if not resumed_from_preemption:
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
                    for block_ids, new_ids in zip(req_state.block_ids,
                                                  new_block_ids):
                        block_ids.extend(new_ids)
            else:
                assert new_block_ids is not None
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                reqs_to_add.append(req_state)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            if new_block_ids is not None:
                self.input_batch.block_table.append_row(
                    new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                # `new_token_ids` was bound in the matching branch above.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index,
                    start_token_index:end_token_index] = new_token_ids
                self.input_batch.num_tokens_no_spec[
                    req_index] = end_token_index
                self.input_batch.num_tokens[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = (
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
            if spec_token_ids:
                num_spec_tokens = len(spec_token_ids)
                start_index = self.input_batch.num_tokens_no_spec[req_index]
                end_token_index = start_index + num_spec_tokens
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
                # NOTE(woosuk): `num_tokens` here may include spec tokens.
                self.input_batch.num_tokens[req_index] += num_spec_tokens

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for request in reqs_to_add:
            self.input_batch.add_request(request)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow the attention backend to reorder the batch if it needs to.
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()
603

604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
    def _init_mrope_positions(self, req_state: CachedRequestState):
        """Compute and store M-RoPE positions for a request.

        Collects multimodal metadata from every item in
        ``req_state.mm_kwargs`` and passes it, together with the prompt token
        ids and the HF config, to
        ``MRotaryEmbedding.get_input_positions_tensor``. The returned
        positions and position delta are stored on ``req_state``.
        """
        # Values converted with `.tolist()` before being collected.
        listified: dict[str, list] = {
            "image_grid_thw": [],
            "video_grid_thw": [],
        }
        # Values collected unchanged.
        as_is: dict[str, list] = {
            "second_per_grid_ts": [],
            "audio_feature_lengths": [],
        }
        audio_in_video = False

        for item in req_state.mm_kwargs:
            data = item.get_data()
            for key, bucket in listified.items():
                value = data.get(key)
                if value is not None:
                    bucket.append(value.tolist())
            for key, bucket in as_is.items():
                value = data.get(key)
                if value is not None:
                    bucket.append(value)
            if data.get("use_audio_in_video") is True:
                audio_in_video = True

        positions, position_delta = (
            MRotaryEmbedding.get_input_positions_tensor(
                req_state.prompt_token_ids,
                hf_config=self.model_config.hf_config,
                image_grid_thw=listified["image_grid_thw"],
                video_grid_thw=listified["video_grid_thw"],
                second_per_grid_ts=as_is["second_per_grid_ts"],
                audio_feature_lengths=as_is["audio_feature_lengths"],
                use_audio_in_video=audio_in_video,
            ))
        req_state.mrope_positions = positions
        req_state.mrope_position_delta = position_delta

634
    def _extract_mm_kwargs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> BatchedTensorInputs:
        """Collect raw multimodal kwargs of the newly scheduled requests.

        Only meaningful for models that take raw multimodal inputs
        (``self.is_multimodal_raw_input_supported``).

        Args:
            scheduler_output: Scheduler output for this step. A falsy value
                yields an empty result.

        Returns:
            One mapping formed by merging every per-modality group yielded
            by ``group_mm_kwargs_by_modality`` (placed on ``self.device``).
            Empty when raw multimodal input is unsupported or there is no
            scheduler output.
        """
        if not self.is_multimodal_raw_input_supported or not scheduler_output:
            return {}

        # Gather the kwargs items of the newly scheduled requests only.
        mm_kwargs = [
            item for req in scheduler_output.scheduled_new_reqs
            for item in req.mm_kwargs
        ]

        # Input all modalities at once. Groups are merged with
        # ``dict.update``, so a key present in several groups keeps the value
        # from the last group.
        mm_kwargs_combined: BatchedTensorInputs = {}
        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                mm_kwargs,
                device=self.device,
                pin_memory=self.pin_memory,
        ):
            mm_kwargs_combined.update(mm_kwargs_group)

        return mm_kwargs_combined

    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
        """Build dummy raw multimodal kwargs for ``num_seqs`` sequences.

        Returns an empty mapping unless the model takes raw multimodal
        inputs. Otherwise it delegates to ``_get_mm_dummy_batch`` with the
        modality returned by ``mm_budget.get_modality_with_max_tokens()``.
        """
        if self.is_multimodal_raw_input_supported:
            budget = self.mm_budget
            assert budget is not None
            modality = budget.get_modality_with_max_tokens()
            return self._get_mm_dummy_batch(modality, num_seqs)
        return {}
664

665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
        cumsum_dtype: Optional[np.dtype] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the cumulative sum and batched arange of the given array.
        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_tokens])
        """
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
        total_num_tokens = cu_num_tokens[-1]
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_tokens] - cumsums_offsets

        return cu_num_tokens, arange

685
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
               np.ndarray, Optional[CommonAttentionMetadata], int]:
        """Build the model inputs and attention metadata for one step.

        Fills the persistent CPU buffers for the scheduled tokens: input ids,
        positions, query start locations, sequence lengths and slot mappings.
        It then copies them to the device and builds per-layer attention
        metadata for every KV cache group. It also activates LoRA adapters
        when ``self.lora_config`` is set.

        Args:
            scheduler_output: Scheduler output for this step. It must schedule
                at least one token, and the batch must hold at least one
                request.

        Returns:
            A 6-tuple ``(attn_metadata, logits_indices, spec_decode_metadata,
            num_scheduled_tokens, spec_decode_common_attn_metadata,
            max_num_scheduled_tokens)``:

            - attn_metadata: mapping from layer name to its attention
              metadata.
            - logits_indices: indices of the tokens whose logits are computed.
            - spec_decode_metadata: ``SpecDecodeMetadata`` when any request
              has draft tokens scheduled, else ``None``.
            - num_scheduled_tokens: per-request scheduled token counts, in
              ``self.input_batch.req_ids`` order.
            - spec_decode_common_attn_metadata: the first KV cache group's
              ``CommonAttentionMetadata`` when ``self.speculative_config`` is
              set, else ``None``.
            - max_num_scheduled_tokens: the largest per-request token count.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit_block_table(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        self.input_batch.block_table.compute_slot_mapping(
            req_indices, positions_np)
        self.input_batch.block_table.commit_slot_mapping(
            total_num_scheduled_tokens)

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention require that.
        self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
        self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
        query_start_loc = self.query_start_loc[:num_reqs + 1]

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
        # Fill unused with 0 for full cuda graph mode.
        self.seq_lens_np[num_reqs:].fill(0)
        self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
        seq_lens = self.seq_lens[:num_reqs]
        max_seq_len = self.seq_lens_np[:num_reqs].max().item()

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        logits_indices_padded = None
        if self.cache_config.kv_sharing_fast_prefill:
            assert self.kv_sharing_fast_prefill_logits_indices is not None
            num_logits = logits_indices.shape[0]
            assert num_logits > 0
            self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
                logits_indices)
            # There might have leftover indices in logits_indices[num_logits:]
            # from previous iterations, whose values may be greater than the
            # batch size in the current iteration. To ensure indices are always
            # valid, we fill the padded indices with the last index.
            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
                logits_indices[-1].item())
            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                    and num_logits <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
                num_logits_padded = self.vllm_config.pad_for_cudagraph(
                    num_logits)
            else:
                num_logits_padded = num_logits
            logits_indices_padded = (
                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
            )

        attn_metadata: dict[str, Any] = {}

        # Used in the below loop.
        query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
        num_computed_tokens_cpu = (
            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
        spec_decode_common_attn_metadata = None

        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

            if isinstance(kv_cache_group_spec.kv_cache_spec,
                          EncoderOnlyAttentionSpec):
                # Encoder-only layers do not have KV cache, so we need to
                # create a dummy block table and slot mapping for them.
                blk_table_tensor = torch.zeros(
                    (num_reqs, 1),
                    dtype=torch.int32,
                    pin_memory=self.pin_memory,
                    device="cpu").to(self.device, non_blocking=True)
                slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                           dtype=torch.int32,
                                           pin_memory=self.pin_memory,
                                           device="cpu").to(self.device,
                                                            non_blocking=True)
                num_common_prefix_blocks = 0
            else:
                blk_table = self.input_batch.block_table[kv_cache_group_id]
                blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
                slot_mapping = blk_table.slot_mapping[:
                                                      total_num_scheduled_tokens]

                # Fill unused with -1. Needed for reshape_and_cache in full cuda
                # graph mode.
                blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
                num_common_prefix_blocks = (
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id])

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=query_start_loc,
                query_start_loc_cpu=query_start_loc_cpu,
                seq_lens=seq_lens,
                seq_lens_cpu=seq_lens_cpu,
                num_computed_tokens_cpu=num_computed_tokens_cpu,
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                max_query_len=max_num_scheduled_tokens,
                max_seq_len=max_seq_len,
                block_table_tensor=blk_table_tensor,
                slot_mapping=slot_mapping,
                causal=True,
            )

            # Speculative decoding reuses the first group's common metadata.
            if self.speculative_config and \
                spec_decode_common_attn_metadata is None:
                spec_decode_common_attn_metadata = common_attn_metadata

            for attn_group in self.attn_groups[kv_cache_group_id]:
                # Prepare for cascade attention if enabled & beneficial.
                common_prefix_len = 0
                builder = attn_group.metadata_builder
                if self.cascade_attn_enabled:
                    common_prefix_len = self._compute_cascade_attn_prefix_len(
                        num_scheduled_tokens,
                        num_common_prefix_blocks,
                        kv_cache_group_spec.kv_cache_spec,
                        builder,
                    )

                attn_metadata_i = (builder.build(
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                ))

                fast_prefill_metadata = attn_metadata_i
                if (self.cache_config.kv_sharing_fast_prefill
                        and self.kv_sharing_fast_prefill_eligible_layers):
                    # Dynamically create a dataclass type that inherits
                    # from attention metadata type but includes additional
                    # fields logits_indices_padded and num_logits_indices
                    # which are required for prefill truncation
                    fast_prefill_metadata_type = (
                        make_kv_sharing_fast_prefill_attention_metadata(
                            metadata_cls=type(attn_metadata_i), ))
                    fast_prefill_metadata = fast_prefill_metadata_type(
                        **dataclasses.asdict(attn_metadata_i),
                        logits_indices_padded=logits_indices_padded,
                        num_logits_indices=logits_indices.size(0),
                    )

                for layer_name in attn_group.layer_names:
                    if (self.cache_config.kv_sharing_fast_prefill
                            and layer_name
                            in self.kv_sharing_fast_prefill_eligible_layers):
                        attn_metadata[layer_name] = fast_prefill_metadata
                        continue
                    attn_metadata[layer_name] = attn_metadata_i

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, logits_indices, spec_decode_metadata,
                num_scheduled_tokens, spec_decode_common_attn_metadata,
                max_num_scheduled_tokens)

    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
        kv_cache_spec: KVCacheSpec,
        attn_metadata_builder: AttentionMetadataBuilder,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.
            kv_cache_spec: KV cache spec of the group. It supplies the block
                size, the number of KV heads and the sliding-window /
                chunked-local-attention settings.
            attn_metadata_builder: Builder whose ``use_cascade_attention``
                decides whether cascade attention is used.

        Returns:
            int: Length of common prefix in tokens (0 when cascade attention is
            not used).
        """
        block_size = kv_cache_spec.block_size
        common_prefix_len = num_common_prefix_blocks * block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        # int(): the NumPy reduction yields a NumPy scalar; keep the return
        # value a plain Python int as annotated.
        common_prefix_len = min(
            common_prefix_len,
            int(self.input_batch.num_computed_tokens_cpu[:num_reqs].min()))
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = common_prefix_len // block_size * block_size
        use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                              (isinstance(kv_cache_spec, FullAttentionSpec)
                               and kv_cache_spec.sliding_window is not None))
        use_local_attention = (
            isinstance(kv_cache_spec, ChunkedLocalAttentionSpec)
            or (isinstance(kv_cache_spec, FullAttentionSpec)
                and kv_cache_spec.attention_chunk_size is not None))
        assert isinstance(kv_cache_spec, AttentionSpec)
        use_cascade = attn_metadata_builder.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=kv_cache_spec.num_kv_heads,
            use_alibi=self.use_alibi,
            use_sliding_window=use_sliding_window,
            use_local_attention=use_local_attention,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
1040
        for index, req_id in enumerate(self.input_batch.req_ids):
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

1078
1079
1080
1081
1082
1083
1084
                MRotaryEmbedding.get_next_input_positions_tensor(
                    out=self.mrope_positions_np,
                    out_offset=dst_start,
                    mrope_position_delta=req.mrope_position_delta,
                    context_len=num_computed_tokens + prompt_part_len,
                    num_new_tokens=completion_part_len,
                )
1085
1086
1087

                mrope_pos_ptr += completion_part_len

1088
1089
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Compute the indices needed to verify draft tokens.

        For each request, the logits of its last ``num_draft_tokens + 1``
        scheduled tokens are selected. These are one per draft token plus a
        trailing "bonus" position.

        Args:
            num_draft_tokens: Number of draft tokens per request.
            cu_num_scheduled_tokens: Cumulative scheduled-token counts per
                request, i.e. each request's end offset in the flat batch.

        Returns:
            ``SpecDecodeMetadata`` containing:

            - device tensors ``cu_num_draft_tokens``, ``logits_indices``,
              ``target_logits_indices`` and ``bonus_logits_indices``;
            - the draft token ids gathered from ``self.input_ids``;
            - ``num_draft_tokens`` as a Python list.

            ``target_logits_indices`` and ``bonus_logits_indices`` index into
            ``logits_indices``.

        Example::

            cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
            num_draft_tokens:         [  3,   0,   2,   0,   1]
            ->
            cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
            logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
                                       206, 207, 208]
            target_logits_indices:    [  0,   1,   2,   5,   6,   9]
            bonus_logits_indices:     [  3,   4,   7,   8,  10]
        """
        # Compute the logits indices.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1

        # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
        # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
            num_sampled_tokens, cumsum_dtype=np.int32)
        # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_indices = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices += arange

        # Compute the bonus logits indices.
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # Compute the draft logits indices.
        # cu_num_draft_tokens: [3, 3, 5, 5, 6]
        # arange: [0, 1, 2, 0, 1, 0]
        cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
            num_draft_tokens, cumsum_dtype=np.int32)
        # [0, 0, 0, 5, 5, 9]
        target_logits_indices = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices += arange

        # TODO: Optimize the CPU -> GPU copy.
        cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
            self.device, non_blocking=True)
        logits_indices = torch.from_numpy(logits_indices).to(self.device,
                                                             non_blocking=True)
        target_logits_indices = torch.from_numpy(target_logits_indices).to(
            self.device, non_blocking=True)
        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
            self.device, non_blocking=True)

        # Compute the draft token ids.
        # Each draft token sits one position after its target logit:
        # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
        draft_token_ids = self.input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]

        metadata = SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder on this step's scheduled encoder inputs.

        Each encoder output is passed through ``scatter_mm_placeholders`` and
        stored in ``self.encoder_cache`` under the item's ``mm_hash``. Does
        nothing when no encoder inputs are scheduled.

        Raises:
            AssertionError: if the encoder produced a different number of
                outputs than there are scheduled items.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_kwargs = list[MultiModalKwargsItem]()
        # list of tuple (mm_hash, position_info)
        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]

            for mm_input_id in encoder_input_ids:
                mm_hash = req_state.mm_hashes[mm_input_id]
                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                mm_hashes_pos.append(
                    (mm_hash, req_state.mm_positions[mm_input_id]))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        encoder_outputs = []
        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                mm_kwargs,
                device=self.device,
                pin_memory=self.pin_memory,
        ):
            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **mm_kwargs_group)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=num_items,
            )

            # Iterating a tensor or a list/tuple yields one output per item.
            encoder_outputs.extend(curr_group_outputs)

        # zip() would silently drop trailing items on a length mismatch.
        assert len(encoder_outputs) == len(mm_hashes_pos), (
            f"Expected {len(mm_hashes_pos)} encoder outputs, "
            f"got {len(encoder_outputs)}.")

        # Cache the encoder outputs by mm_hash
        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
            self.encoder_cache[mm_hash] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
        shift_computed_tokens: int = 0,
    ) -> list[torch.Tensor]:
        """Gather cached encoder embeddings needed by this step's tokens.

        For every request in the batch, it takes the slices of cached encoder
        outputs whose placeholder range overlaps the scheduled token range
        ``[num_computed_tokens, num_computed_tokens + num_scheduled_tokens)``.

        Args:
            scheduler_output: Provides ``num_scheduled_tokens`` per request id.
            shift_computed_tokens: Offset added to each request's
                ``num_computed_tokens`` before computing the overlap.

        Returns:
            Embedding slices in batch order, and within a request in
            ``mm_positions`` order. The early ``break`` below relies on
            ``mm_positions`` being sorted by offset.

        Raises:
            AssertionError: if a needed encoder output is missing from
                ``self.encoder_cache``.
        """
        mm_embeds: list[torch.Tensor] = []
        for req_id in self.input_batch.req_ids:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = \
                req_state.num_computed_tokens + shift_computed_tokens
            mm_positions = req_state.mm_positions
            mm_hashes = req_state.mm_hashes
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens,
                )
                assert start_idx < end_idx

                mm_hash = mm_hashes[i]
                encoder_output = self.encoder_cache.get(mm_hash)
                assert encoder_output is not None,\
                    f"Encoder cache miss for {mm_hash}."

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
        return mm_embeds

    def get_model(self) -> nn.Module:
        """Return the underlying model, unwrapping a ``CUDAGraphWrapper``.

        If ``self.model`` is a ``CUDAGraphWrapper``, the wrapped module is
        returned via ``unwrap()``; otherwise ``self.model`` is returned as-is.
        """
        # get raw model out of the cudagraph wrapper.
        if isinstance(self.model, CUDAGraphWrapper):
            return self.model.unwrap()
        return self.model

    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        """List the generation tasks the loaded model supports.

        A transcription-only model yields exactly ``["transcription"]``.
        Otherwise the result contains "generate" if the model is a text
        generation model, followed by "transcription" if it supports
        transcription.
        """
        model = self.get_model()
        supported_tasks = list[GenerationTask]()

        if is_text_generation_model(model):
            supported_tasks.append("generate")

        if supports_transcription(model):
            if model.supports_transcription_only:
                # Overrides anything collected above.
                return ["transcription"]

            supported_tasks.append("transcription")

        return supported_tasks

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """List the pooling tasks the loaded model supports.

        Returns an empty list for non-pooling models. Otherwise returns the
        tasks reported by ``model.pooler.get_supported_tasks()``, except that
        "encode" is removed (and an explanation is logged via
        ``logger.info_once``) when chunked prefill is enabled.
        """
        model = self.get_model()
        if not is_pooling_model(model):
            return []

        supported_tasks = list(model.pooler.get_supported_tasks())

        if (self.scheduler_config.chunked_prefill_enabled
                and "encode" in supported_tasks):
            supported_tasks.remove("encode")

            logger.info_once("Chunked prefill is not supported with "
                             "encode task which using ALL pooling. "
                             "Please turn off chunked prefill by "
                             "`--no-enable-chunked-prefill` before using it.")

        return supported_tasks

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported for the configured runner type.

        Generation tasks are reported when ``runner_type == "generate"`` and
        pooling tasks when ``runner_type == "pooling"``; any other runner
        type yields an empty tuple.
        """
        tasks = list[SupportedTask]()

        runner_type = self.model_config.runner_type
        if runner_type == "generate":
            tasks.extend(self.get_supported_generation_tasks())
        elif runner_type == "pooling":
            tasks.extend(self.get_supported_pooling_tasks())

        return tuple(tasks)

    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        """Mask ``logits`` in place with the structured-output bitmask.

        Does nothing when ``scheduler_output.grammar_bitmask`` is None.
        Otherwise the scheduler's compacted bitmask (rows only for
        structured-output requests, in the scheduler's order, with
        ``1 + num_spec_tokens`` rows per request) is scattered into a bitmask
        with one row per row of ``logits`` and applied through xgrammar's
        torch.compile implementation.

        Args:
            scheduler_output: Supplies ``grammar_bitmask`` (a NumPy array),
                ``structured_output_request_ids`` and
                ``scheduled_spec_decode_tokens``.
            logits: Logits for this step's sampled positions; modified in
                place.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        # We receive the structured output bitmask from the scheduler,
        # compacted to contain bitmasks only for structured output requests.
        # The order of the requests in the bitmask is not guaranteed to be the
        # same as the order of the requests in the gpu runner's batch. We need
        # to sort the bitmask to match the order of the requests used here.

        # Get the batch indices of the structured output requests.
        # Keep track of the number of speculative tokens scheduled for every
        # request in the batch, as the logit indices are offset by this amount.
        struct_out_req_batch_indices: dict[str, int] = {}
        cumulative_offset = 0
        seq = sorted(self.input_batch.req_id_to_index.items(),
                     key=lambda x: x[1])
        for req_id, batch_index in seq:
            logit_index = batch_index + cumulative_offset
            cumulative_offset += len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            if req_id in scheduler_output.structured_output_request_ids:
                struct_out_req_batch_indices[req_id] = logit_index

        out_indices = []

        # Reorder the bitmask to match the order of the requests in the batch.
        sorted_bitmask = np.full(shape=(logits.shape[0],
                                        grammar_bitmask.shape[1]),
                                 fill_value=-1,
                                 dtype=grammar_bitmask.dtype)
        cumulative_index = 0
        seq = sorted(scheduler_output.structured_output_request_ids.items(),
                     key=lambda x: x[1])
        for req_id, _ in seq:
            logit_index = struct_out_req_batch_indices[req_id]
            num_spec_tokens = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            for i in range(1 + num_spec_tokens):
                sorted_bitmask[logit_index + i] = \
                    grammar_bitmask[cumulative_index + i]
                out_indices.append(logit_index + i)
            cumulative_index += 1 + num_spec_tokens
        grammar_bitmask = sorted_bitmask

        # If the length of out indices and the logits have the same shape
        # we don't need to pass indices to the kernel,
        # since the bitmask is already aligned with the logits.
        skip_out_indices = len(out_indices) == logits.shape[0]

        # Serialization of np.ndarray is much more efficient than a tensor,
        # so we receive it in that format.
        grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()

        # Force use of the torch.compile implementation from xgrammar to work
        # around issues with the Triton kernel in concurrent structured output
        # scenarios. See PR #19565 and issues #19493, #18376 for details.
        xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
            logits,
            grammar_bitmask.to(self.device, non_blocking=True),
            indices=out_indices if not skip_out_indices else None,
        )

    def sync_and_slice_intermediate_tensors(
            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
            sync_self: bool) -> IntermediateTensors:
        """Return slices of ``self.intermediate_tensors`` for this step.

        Args:
            num_tokens: Number of (padded) tokens in this step. Each returned
                tensor is sliced to this length, except the "residual" entry,
                which is sliced to ``num_tokens // tp`` when it is scattered
                under sequence parallelism.
            intermediate_tensors: Incoming tensors to copy from; required
                when ``sync_self`` is True.
            sync_self: If True, first copy the leading rows of every incoming
                tensor into ``self.intermediate_tensors`` (non-blocking).

        Returns:
            A new ``IntermediateTensors`` whose values are slices of
            ``self.intermediate_tensors``.
        """
        assert self.intermediate_tensors is not None

        tp = self.vllm_config.parallel_config.tensor_parallel_size
        enabled_sp = self.compilation_config.pass_config. \
            enable_sequence_parallelism
        if enabled_sp:
            # When sequence parallelism is enabled, we always pad num_tokens
            # to be a multiple of tensor_parallel_size (tp) earlier
            assert num_tokens % tp == 0
        is_residual_scattered = tp > 1 and enabled_sp \
            and num_tokens % tp == 0

        # When sequence parallelism is enabled, the "residual" tensor is sharded
        # across tensor parallel ranks, so each rank only needs its own slice.
        if sync_self:
            assert intermediate_tensors is not None
            for k, v in intermediate_tensors.items():
                is_scattered = k == "residual" and is_residual_scattered
                copy_len = num_tokens // tp if is_scattered else \
                    num_tokens
                self.intermediate_tensors[k][:copy_len].copy_(
                    v[:copy_len], non_blocking=True)

        return IntermediateTensors({
            k:
            v[:num_tokens // tp]
            if k == "residual" and is_residual_scattered else v[:num_tokens]
            for k, v in self.intermediate_tensors.items()
        })

    def eplb_step(self,
                  is_dummy: bool = False,
                  is_profile: bool = False) -> None:
        """
        Step for the EPLB (Expert Parallelism Load Balancing) state.

        No-op unless ``parallel_config.enable_eplb`` is set. Otherwise passes
        the unwrapped model (which must be a mixture-of-experts model) and
        the ``is_dummy`` / ``is_profile`` flags to ``self.eplb_state.step``,
        with stats logging controlled by ``eplb_config.log_balancedness``.
        """
        if not self.parallel_config.enable_eplb:
            return

        assert self.eplb_state is not None
        model = self.get_model()
        assert is_mixture_of_experts(model)
        self.eplb_state.step(
            model,
            is_dummy,
            is_profile,
            log_stats=self.parallel_config.eplb_config.log_balancedness,
        )

    def get_dp_padding(self,
                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
        """Compute the padding that equalizes token counts across DP ranks.

        Args:
            num_tokens: Number of input tokens on this rank.

        Returns:
            ``(num_pad, num_tokens_after_padding)``. ``num_pad`` is the
            number of tokens to add on this rank to reach the maximum of
            ``DPMetadata.num_tokens_across_dp``; ``num_tokens_after_padding``
            is a CPU int32 tensor of length ``dp_size`` filled with that
            maximum. Returns ``(0, None)`` when ``dp_size == 1`` or
            ``enforce_eager`` is set.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        dp_rank = self.vllm_config.parallel_config.data_parallel_rank

        # For DP: Don't pad when setting enforce_eager.
        # This lets us set enforce_eager on the prefiller in a P/D setup and
        # still use CUDA graphs (enabled by this padding) on the decoder.
        #
        # TODO(tms) : There are many cases where padding is enabled for
        # prefills, causing unnecessary and excessive padding of activations.

        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
            # Early exit.
            return 0, None

        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
            num_tokens, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                                dp_size,
                                                device="cpu",
                                                dtype=torch.int32)
        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding

    def _pool(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        num_scheduled_tokens_np: np.ndarray,
        kv_connector_output: Optional[KVConnectorOutput],
    ) -> ModelRunnerOutput:
        """Run the pooler on this step's hidden states and package the result.

        Every request in the batch must be a pooling request. A request's
        pooled output is reported only when its sequence length equals its
        prompt length; otherwise its ``pooler_output`` entry is None.

        Args:
            hidden_states: Model output; only the first
                ``num_scheduled_tokens`` rows are used.
            num_scheduled_tokens: Total number of tokens scheduled this step.
            num_scheduled_tokens_np: Per-request scheduled token counts, used
                to build the pooling cursor.
            kv_connector_output: Forwarded unchanged into the returned
                ``ModelRunnerOutput``.
        """
        assert self.input_batch.num_reqs == len(
            self.input_batch.pooling_params), (
                "Either all or none of the requests in"
                " a batch must be pooling request")

        hidden_states = hidden_states[:num_scheduled_tokens]
        pooling_metadata = self.input_batch.pooling_metadata
        pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
                                              device=hidden_states.device)
        seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]

        # Pooling models D2H & synchronize occurs in pooler.py:build_output
        raw_pooler_output = self.model.pooler(
            hidden_states=hidden_states, pooling_metadata=pooling_metadata)

        pooler_output: list[Optional[torch.Tensor]] = []
        for raw_output, seq_len, prompt_len in zip(
                raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
            # Presumably a partially prefilled prompt (seq_len < prompt_len)
            # has no final pooled value yet, so it reports None.
            output = raw_output.data if seq_len == prompt_len else None
            pooler_output.append(output)

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=[],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=pooler_output,
            kv_connector_output=kv_connector_output,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        """Run one forward step for the batch described by ``scheduler_output``.

        Updates cached request state, prepares inputs (padding for CUDA
        graphs, sequence parallelism and data parallelism), runs the model,
        and then:

        * if no tokens are scheduled, returns ``EMPTY_MODEL_RUNNER_OUTPUT``
          or, when a KV transfer group exists, the result of
          ``kv_connector_no_forward`` (no forward pass is run);
        * on a non-last pipeline-parallel rank, returns the model's
          ``IntermediateTensors``, unless ``broadcast_pp_output`` is set, in
          which case they are sent on and the rank continues with the
          broadcast logits;
        * for pooling batches, returns the result of ``_pool``;
        * otherwise applies any grammar bitmask, samples (with rejection
          sampling for speculative decoding), records the sampled tokens in
          the input batch and request state, optionally proposes draft
          tokens, and returns a ``ModelRunnerOutput``.
        """
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)

        # Prepare the decoder inputs.
        (attn_metadata, logits_indices, spec_decode_metadata,
         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
         max_query_len) = self._prepare_inputs(scheduler_output)

        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            # Pad tokens to multiple of tensor_parallel_size when
            # enabled collective fusion for SP
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            if self.compilation_config.pass_config. \
                enable_sequence_parallelism and tp_size > 1:
                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            else:
                num_input_tokens = num_scheduled_tokens

        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.supports_mm_inputs:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.supports_mm_inputs and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            inputs_embeds_scheduled = self.model.get_input_embeddings(
                input_ids=self.input_ids[:num_scheduled_tokens],
                multimodal_embeddings=mm_embeds or None,
            )

            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(
                inputs_embeds_scheduled)

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            model_kwargs = {
                **self._init_model_kwargs(num_scheduled_tokens),
                **self._extract_mm_kwargs(scheduler_output),
            }
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = self._init_model_kwargs(num_input_tokens)

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
            num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                           uniform_decode=uniform_decode)
        cudagraph_runtime_mode, batch_descriptor = \
            self.cudagraph_dispatcher.dispatch(batch_descriptor)

        # Run the model.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=batch_descriptor,
        ), self.maybe_get_kv_connector_output(
                scheduler_output) as kv_connector_output:

            model_output = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **model_kwargs,
            )

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping mirco-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            assert isinstance(hidden_states, IntermediateTensors)
            if not broadcast_pp_output:
                hidden_states.kv_connector_output = kv_connector_output
                return hidden_states
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            # NOTE(review): `sample_hidden_states` is not assigned on this
            # path but is passed to propose_draft_token_ids below when
            # speculative decoding is enabled — confirm that combination
            # is unsupported.
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, kv_connector_output)

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            assert logits is not None
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output.num_scheduled_tokens,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )
        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        req_ids = self.input_batch.req_ids
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            if not sampled_ids:
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                           start_idx:end_idx] = sampled_ids
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
            req_id = req_ids[req_idx]
            req_state = self.requests[req_id]
            req_state.output_token_ids.extend(sampled_ids)

        if self.speculative_config:
            assert spec_decode_common_attn_metadata is not None
            self._draft_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                valid_sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                spec_decode_common_attn_metadata,
            )

        self.eplb_step()

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            kv_connector_output=kv_connector_output,
            num_nans_in_logits=num_nans_in_logits,
        )

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return and clear the pending draft token ids.

        Returns None if no drafts are pending. Drafts stored as a tensor are
        converted to nested Python lists. The pending value is reset to None,
        so calling again before new drafts are proposed returns None.
        """
        if self._draft_token_ids is None:
            return None
        req_ids = self.input_batch.req_ids
        if isinstance(self._draft_token_ids, torch.Tensor):
            draft_token_ids = self._draft_token_ids.tolist()
        else:
            draft_token_ids = self._draft_token_ids
        self._draft_token_ids = None
        return DraftTokenIds(req_ids, draft_token_ids)

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        common_attn_metadata: CommonAttentionMetadata,
    ) -> Union[list[list[int]], torch.Tensor]:
        """Propose draft tokens for the next step with the configured drafter.

        Dispatches on ``self.speculative_config``:

        * "ngram": delegates to ``propose_ngram_draft_token_ids``.
        * "medusa": feeds the hidden state at each request's last accepted
          position to the Medusa drafter.
        * EAGLE-style (``use_eagle()``): builds the target token ids,
          positions and hidden states (trimming rejected draft positions via
          ``self.drafter.prepare_inputs`` when ``spec_decode_metadata`` is
          set) plus the next token ids, and calls the EAGLE drafter.

        Args:
            scheduler_output: Scheduler output for the current step.
            sampled_token_ids: Accepted tokens per request; an empty list
                marks a request whose sampled token was discarded.
            sampling_metadata: Sampling metadata forwarded to the drafter.
            hidden_states: Target-model hidden states for all input tokens.
            sample_hidden_states: Hidden states at the sampled positions.
            aux_hidden_states: Auxiliary hidden states, used when
                ``self.use_aux_hidden_state_outputs`` is set.
            spec_decode_metadata: Draft-token metadata for this step, or
                None when no draft tokens were scheduled.
            common_attn_metadata: Attention metadata passed to the EAGLE
                drafter.

        Returns:
            The proposed draft token ids.

        Raises:
            ValueError: If the speculative decoding method is none of the
                supported ones.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            draft_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                # Pick, per request, the row of its last accepted token
                # within that request's (num_draft + 1) sampled rows.
                indices = []
                offset = 0
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            draft_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            req_ids = self.input_batch.req_ids
            next_token_ids: list[int] = []
            for i, token_ids in enumerate(sampled_token_ids):
                if token_ids:
                    # Common case.
                    next_token_id = token_ids[-1]
                else:
                    # Partial prefill (rare case).
                    # Get the next token id from the request state.
                    req_id = req_ids[i]
                    req_state = self.requests[req_id]
                    seq_len = (req_state.num_computed_tokens +
                               scheduler_output.num_scheduled_tokens[req_id])
                    next_token_id = req_state.get_token_id(seq_len)
                next_token_ids.append(next_token_id)
            next_token_ids = torch.tensor(next_token_ids,
                                          dtype=torch.int32,
                                          device=self.device)

            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
            else:
                # TODO(woosuk): Refactor this.
                num_draft_tokens = spec_decode_metadata.num_draft_tokens
                num_rejected_tokens = [
                    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                    for i, n in enumerate(num_draft_tokens)
                ]
                num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
                                                       dtype=torch.int32)
                common_attn_metadata, token_indices = (
                    self.drafter.prepare_inputs(common_attn_metadata,
                                                num_rejected_tokens_cpu))

                target_token_ids = self.input_ids[token_indices]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[token_indices]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
            mm_embeds = None
            if self.supports_mm_inputs:
                mm_embeds = self._gather_mm_embeddings(scheduler_output,
                                                       shift_computed_tokens=1)

            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                next_token_ids=next_token_ids,
                sampling_metadata=sampling_metadata,
                common_attn_metadata=common_attn_metadata,
                mm_embeds=mm_embeds,
            )
        else:
            # Previously this fell through to an UnboundLocalError on
            # `draft_token_ids`; fail with an explicit message instead.
            raise ValueError("Unsupported speculative decoding method: "
                             f"{self.speculative_config.method}")
        return draft_token_ids

    def propose_ngram_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """Propose n-gram draft tokens for every request in the batch.

        Args:
            sampled_token_ids: Per-request token ids sampled this step, in
                batch order (index ``i`` matches ``input_batch.req_ids[i]``).

        Returns:
            One list of draft token ids per request. A list is empty when the
            request sampled nothing, uses sampling parameters unsupported by
            speculative decoding, has already reached ``max_model_len``, or
            when the drafter proposes nothing.
        """
        # TODO(woosuk): Optimize.
        batch = self.input_batch
        req_ids = batch.req_ids

        def _draft_for(index: int, sampled: list[int]) -> list[int]:
            # Nothing was sampled: skip speculative decoding.
            if len(sampled) == 0:
                return []
            # Sampling params incompatible with spec decode.
            if req_ids[index] in batch.spec_decode_unsupported_reqs:
                return []
            num_tokens = batch.num_tokens_no_spec[index]
            # Already at the max model length; no room for drafts.
            if num_tokens >= self.max_model_len:
                return []
            proposal = self.drafter.propose(
                batch.token_ids_cpu[index, :num_tokens])
            if proposal is None or len(proposal) == 0:
                return []
            return proposal.tolist()

        return [
            _draft_for(index, sampled)
            for index, sampled in enumerate(sampled_token_ids)
        ]

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Apply field overrides to selected configs held by this runner.

        Args:
            overrides: Mapping from config attribute name (only
                ``"load_config"`` and ``"model_config"`` are accepted) to the
                overrides passed to the module-level ``update_config`` helper.
                The result replaces the attribute on ``self``.
        """
        permitted = {"load_config", "model_config"}
        for name, field_overrides in overrides.items():
            assert name in permitted, \
                f"Config `{name}` not supported. " \
                f"Allowed configs: {permitted}"
            current = getattr(self, name)
            # Note: this resolves to the module-level helper, not this method.
            setattr(self, name, update_config(current, field_overrides))

    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the model and attach LoRA, drafter, EPLB and compile wrappers.

        Args:
            eep_scale_up: the model loading is for elastic EP scale up. When
                set, the per-rank physical expert count is received by
                broadcast from EP-group rank 0, and the previous expert-load
                state is received via ``EplbState.recv_state()`` before the
                model is loaded.
        """
        logger.info("Starting to load model %s...", self.model_config.model)

        # State forwarded to EplbState.build(); populated only on scale-up.
        global_expert_load = old_global_expert_indices = rank_mapping = None
        if eep_scale_up:
            from vllm.distributed.parallel_state import get_ep_group
            ep_group = get_ep_group()
            expert_count_buf = torch.empty(1, dtype=torch.int32, device="cpu")
            torch.distributed.broadcast(expert_count_buf,
                                        group=ep_group.cpu_group,
                                        group_src=0)
            num_local_physical_experts = int(expert_count_buf.item())
            new_ep_size = ep_group.world_size
            global_expert_load, old_global_expert_indices = (
                EplbState.recv_state())
            num_logical_experts = global_expert_load.shape[1]
            self.parallel_config.eplb_config.num_redundant_experts = (
                num_local_physical_experts * new_ep_size - num_logical_experts)
            num_old_physical_experts = old_global_expert_indices.shape[1]
            assert num_old_physical_experts % num_local_physical_experts == 0
            old_ep_size = num_old_physical_experts // num_local_physical_experts
            # Ranks that existed before the scale-up keep their rank ids.
            rank_mapping = {rank: rank for rank in range(old_ep_size)}

        with DeviceMemoryProfiler() as profiler:
            load_start = time.perf_counter()
            model_loader = get_model_loader(self.load_config)
            logger.info("Loading model from scratch...")
            self.model = model_loader.load_model(
                vllm_config=self.vllm_config, model_config=self.model_config)
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)
            if hasattr(self, "drafter"):
                logger.info("Loading drafter model...")
                self.drafter.load_model(self.model)
            if self.use_aux_hidden_state_outputs:
                if not supports_eagle3(self.model):
                    raise RuntimeError(
                        "Model does not support EAGLE3 interface but "
                        "aux_hidden_state_outputs was requested")
                self.model.set_aux_hidden_state_layers(
                    self.model.get_eagle3_aux_hidden_state_layers())
            load_end = time.perf_counter()
        self.model_memory_usage = profiler.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes,
                    load_end - load_start)
        prepare_communication_buffer_for_model(self.model)

        if (is_mixture_of_experts(self.model)
                and self.parallel_config.enable_eplb):
            logger.info("EPLB is enabled for model %s.",
                        self.model_config.model)
            self.eplb_state = EplbState.build(
                self.model,
                self.device,
                self.parallel_config,
                global_expert_load,
                old_global_expert_indices,
                rank_mapping,
            )

        vllm_compilation_config = self.vllm_config.compilation_config
        if (vllm_compilation_config.level == CompilationLevel.DYNAMO_AS_IS
                and supports_dynamo()):
            # torch.compile owns the whole model here; vLLM's cudagraph
            # wrapper below is not applied in this mode.
            backend = vllm_compilation_config.init_backend(self.vllm_config)
            compilation_counter.dynamo_as_is_count += 1
            self.model.compile(
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=backend)
        elif self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            # For other compilation levels, cudagraph behavior is controlled
            # by vLLM's CUDAGraphWrapper and cudagraph dispatcher.
            self.model = CUDAGraphWrapper(self.model,
                                          self.vllm_config,
                                          runtime_mode=CUDAGraphMode.FULL)

    def reload_weights(self) -> None:
        """Reload weights in place into the already-loaded model.

        Raises:
            AssertionError: if ``self.model`` is missing or ``None``.
        """
        assert getattr(self, "model", None) is not None, \
            "Cannot reload weights before model is loaded."
        loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        loader.load_weights(self.get_model(), model_config=self.model_config)

    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the current model with Tensorizer.

        Args:
            tensorizer_config: Tensorizer settings forwarded unchanged to
                ``TensorizerLoader.save_model``.
        """
        TensorizerLoader.save_model(
            self.get_model(),
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )

    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: dict[str, int],
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for the requests that asked for them.

        Prompt logprobs may be produced over several steps (chunked prefill).
        Partial results accumulate in
        ``self.input_batch.in_progress_prompt_logprobs_cpu``. Once a
        request's final chunk is processed, it is removed from both
        ``self.input_batch.num_prompt_logprobs`` and the in-progress dict.

        Args:
            hidden_states: Hidden states for all tokens scheduled this step;
                a request's rows start at ``query_start_loc_np[req_idx]``.
            num_scheduled_tokens: Tokens scheduled this step, per request id.

        Returns:
            Mapping from request id to its CPU ``LogprobsTensors``, only for
            requests whose prompt logprobs are complete after this step. The
            non-blocking device->host copies into those tensors are
            synchronized before returning.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
            num_tokens = num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if logprobs_tensors is None:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # For a chunked prefill, only this chunk's positions are used.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only this chunk's targets are
            # moved to the device, not the whole prompt on every chunk.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok +
                                         num_logits]).to(self.device,
                                                         non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        if prompt_logprobs_dict:
            self._sync_device()

        return prompt_logprobs_dict

    def _get_nans_in_logits(
        self,
        logits: Optional[torch.Tensor],
    ) -> dict[str, int]:
        """Count NaN entries in each request's row of ``logits``.

        Args:
            logits: Logits indexed by batch position along dim 0, or ``None``.

        Returns:
            Mapping from request id to its NaN count. Requests whose batch
            index is past ``logits.shape[0]`` get 0, and every request gets 0
            when ``logits`` is ``None``. Returns ``{}`` if indexing fails
            (e.g. for 1-D logits, where the per-row count is a 0-d array).
        """
        try:
            if logits is None:
                return {req_id: 0 for req_id in self.input_batch.req_ids}

            nan_counts = logits.isnan().sum(dim=-1).cpu().numpy()
            num_rows = logits.shape[0]
            req_id_to_index = self.input_batch.req_id_to_index
            num_nans_in_logits = {}
            for req_id in self.input_batch.req_ids:
                req_index = req_id_to_index[req_id]
                num_nans_in_logits[req_id] = (int(nan_counts[req_index])
                                              if req_index < num_rows else 0)
            return num_nans_in_logits
        except IndexError:
            return {}

    @contextmanager
    def maybe_randomize_inputs(self, input_ids: Optional[torch.Tensor]):
        """
        Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
        This is to help balance expert-selection
         - during profile_run
         - during DP rank dummy run

        Randomization only happens when ``data_parallel_size > 1`` and
        ``input_ids`` is a tensor. Callers that run on ``inputs_embeds``
        pass ``None`` and get a no-op. When randomization happens,
        ``input_ids`` is zero-filled on exit, including exit via an
        exception.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
        if not randomize_inputs or input_ids is None:
            yield
            return

        logger.debug_once("Randomizing dummy data for DP Rank")
        # Sample exactly as many ids as input_ids holds. No cache is needed:
        # the ids are consumed immediately by the copy below.
        random_ids = torch.randint_like(
            input_ids,
            low=0,
            high=self.model_config.get_vocab_size(),
            dtype=input_ids.dtype)
        input_ids.copy_(random_ids, non_blocking=True)
        try:
            yield
        finally:
            # Reset the dummy inputs even if the forward pass raised.
            input_ids.fill_(0)

    def _get_mm_dummy_batch(
        self,
        modality: str,
        max_items_per_batch: int,
    ) -> BatchedTensorInputs:
        """Dummy data for profiling and precompiling multimodal models.

        One dummy item of ``modality`` (requested from the registry with
        ``seq_len=self.max_num_tokens``) is repeated ``max_items_per_batch``
        times. The kwargs of the first group yielded by
        ``group_mm_kwargs_by_modality`` are returned.
        """
        decoder_dummy = self.mm_registry.get_decoder_dummy_data(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
            mm_counts={modality: 1},
        )

        # Result in the maximum GPU consumption of the model
        sample_item = decoder_dummy.multi_modal_data[modality][0]
        grouped = group_mm_kwargs_by_modality(
            [sample_item] * max_items_per_batch,
            device=self.device,
            pin_memory=self.pin_memory,
        )
        _, _, mm_kwargs_group = next(iter(grouped))
        return mm_kwargs_group

    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        force_attention: bool = False,
        uniform_decode: bool = False,
        skip_eplb: bool = False,
        is_profile: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run a dummy forward pass to warm up/profile run or capture the
        CUDA graph for the model.

        Args:
            num_tokens: Number of tokens to run the dummy forward pass. It is
                first increased by the DP padding from ``get_dp_padding``.
            cudagraph_runtime_mode: used to control the behavior.
                - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
                    needed.
            force_attention: If True, always create attention metadata. Used
                to warm up attention backend when mode is NONE.
            uniform_decode: If True, the batch is a uniform decode batch.
            skip_eplb: If True, skip EPLB state update.
            is_profile: If True, this is a profile run.

        Returns:
            ``(hidden_states, hidden_states[last_token_of_each_request])``.
        """
        assert cudagraph_runtime_mode in {
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        }

        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
        num_tokens += num_pad

        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(), we use different graphs and/or
        # modes for mixed prefill-decode batches vs. uniform decode batches.
        # A uniform decode batch means that all requests have identical
        # query length, except a potential virtual request (shorter) in the
        # batch that accounts for padding. A uniform decode batch is either
        # plain decode (max_query_len == 1) or speculative decode
        # (max_query_len == 1 + num_spec_decode_tokens).
        #
        # When setting max_query_len = 1, we switch to and capture the
        # optimized routine of FA2 for pure decode, i.e., Flashdecode + an
        # optimization for GQA/MQA.
        if uniform_decode:
            max_query_len = self.uniform_decode_query_len
        else:
            max_query_len = num_tokens

        # Split num_tokens across dummy requests (bounded by max_num_seqs) so
        # that, e.g. for LoRA dummy runs, the requests collectively hold
        # num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        if uniform_decode:
            num_reqs = cdiv(num_tokens, max_query_len)
            assert num_reqs <= max_num_reqs, \
                "Do not capture num_reqs > max_num_reqs for uniform batch"
            per_req_tokens = [max_query_len] * num_reqs
            leftover = num_tokens % max_query_len
            if leftover != 0:
                # The final (virtual) request carries the shorter remainder.
                per_req_tokens[-1] = leftover
        else:
            num_reqs = min(num_tokens, max_num_reqs)
            base_tokens, extra_tokens = divmod(num_tokens, num_reqs)
            per_req_tokens = [base_tokens] * num_reqs
            per_req_tokens[-1] += extra_tokens

        assert sum(per_req_tokens) == num_tokens
        assert len(per_req_tokens) == num_reqs
        num_scheduled_tokens = np.array(per_req_tokens, dtype=np.int32)

        # Attention metadata is built when force_attention is set, or when
        # capturing a FULL cudagraph; otherwise it stays None.
        attn_metadata: Optional[dict[str, Any]] = None
        if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
            attn_metadata = {}

            # Make sure max_model_len is used at the graph capture time.
            self.seq_lens_np[:num_reqs] = self.max_model_len
            self.seq_lens_np[num_reqs:] = 0
            self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)

            for group_id, group_spec in enumerate(
                    self.kv_cache_config.kv_cache_groups):
                block_table = self.input_batch.block_table[group_id]
                common_attn_metadata = CommonAttentionMetadata(
                    query_start_loc=self.query_start_loc[:num_reqs + 1],
                    query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                                 1],
                    seq_lens=self.seq_lens[:num_reqs],
                    seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
                    num_computed_tokens_cpu=self.input_batch.
                    num_computed_tokens_cpu_tensor[:num_reqs],
                    num_reqs=num_reqs,
                    num_actual_tokens=num_tokens,
                    max_query_len=max_query_len,
                    max_seq_len=self.max_model_len,
                    block_table_tensor=block_table.get_device_tensor()
                    [:num_reqs],
                    slot_mapping=block_table.slot_mapping[:num_tokens],
                    causal=True)

                for attn_group in self.attn_groups[group_id]:
                    group_metadata = attn_group.metadata_builder\
                        .build_for_cudagraph_capture(common_attn_metadata)
                    attn_metadata.update(
                        dict.fromkeys(group_spec.layer_names, group_metadata))

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            if self.supports_mm_inputs:
                # Multimodal models are fed through inputs_embeds.
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
                model_kwargs = {
                    **self._init_model_kwargs(num_tokens),
                    **self._dummy_mm_kwargs(num_reqs),
                }
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None
                model_kwargs = self._init_model_kwargs(num_tokens)

            positions = (self.mrope_positions[:, :num_tokens]
                         if self.uses_mrope else self.positions[:num_tokens])

            intermediate_tensors = None
            if not get_pp_group().is_first_rank:
                # Non-first PP ranks consume intermediate tensors; allocate
                # the persistent buffer lazily on first use.
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device))
                intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                    num_tokens, None, False)

            batch_descriptor = None
            if cudagraph_runtime_mode != CUDAGraphMode.NONE:
                # filter out the valid batch descriptor
                dispatched_mode, batch_descriptor = \
                    self.cudagraph_dispatcher.dispatch(
                        BatchDescriptor(num_tokens=num_tokens,
                                        uniform_decode=uniform_decode))
                # sanity check
                assert cudagraph_runtime_mode == dispatched_mode, (
                    f"Cudagraph runtime mode mismatch at dummy_run. "
                    f"Expected {dispatched_mode}, but got "
                    f"{cudagraph_runtime_mode}.")

            with self.maybe_randomize_inputs(input_ids), set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    num_tokens_across_dp=num_tokens_across_dp,
                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                    batch_descriptor=batch_descriptor):
                outputs = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **model_kwargs,
                )

            if self.use_aux_hidden_state_outputs:
                hidden_states, _aux_hidden_states = outputs
            else:
                hidden_states = outputs

            if self.speculative_config and self.speculative_config.use_eagle():
                assert isinstance(self.drafter, EagleProposer)
                self.drafter.dummy_run(num_tokens)

        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
        # requests to process.
        # However, in DP settings, there may be cases when some DP ranks do
        # not have any requests to process, so they're executing dummy batches.
        # In such cases, we still have to trigger EPLB to make sure
        # ranks execute the rearrangement in synchronization.
        if not skip_eplb:
            self.eplb_step(is_dummy=True, is_profile=is_profile)

        # Index of the last token of each dummy request.
        last_token_indices = np.cumsum(num_scheduled_tokens) - 1
        return hidden_states, hidden_states[last_token_indices]

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the sampler (and rejection sampler, if any) on dummy data.

        One dummy request is created per row of the computed logits. When a
        speculative config is present, the rejection sampler is also run
        once with a single draft token per request.

        Raises:
            RuntimeError: with a hint to lower ``max_num_seqs`` /
                ``gpu_memory_utilization`` when sampling runs out of memory;
                other runtime errors are re-raised unchanged.
        """
        # The dummy hidden states may contain special values like `inf` or
        # `nan`. Random values are used instead so the sampler doesn't break.
        hidden_states = torch.rand_like(hidden_states)

        logits = self.model.compute_logits(hidden_states, None)
        num_reqs = logits.size(0)

        def filled(value):
            # A per-request tensor holding the same value for every request.
            return torch.full((num_reqs, ), value, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=filled(0.5),
            all_greedy=False,
            all_random=False,
            top_p=filled(0.9),
            top_k=filled(logits.size(1) - 1),
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=filled(0.1),
            presence_penalties=filled(0.1),
            repetition_penalties=filled(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=LogitsProcessors(),
        )
        try:
            sampler_output = self.sampler(logits=logits,
                                          sampling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            raise RuntimeError(
                "CUDA out of memory occurred when warming up sampler with "
                f"{num_reqs} dummy requests. Please try lowering "
                "`max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine.") from e

        if self.speculative_config:
            # One dummy draft token per request.
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device)

            num_draft_tokens = sum(len(ids) for ids in draft_token_ids)
            # No draft probabilities are supplied to the rejection sampler.
            draft_probs = None
            target_logits = torch.randn(num_draft_tokens,
                                        logits.shape[-1],
                                        device=self.device,
                                        dtype=logits.dtype)
            # NOTE(woosuk): Here, we should use int32 because the sampler uses
            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
            # will occur at runtime.
            bonus_token_ids = torch.zeros(num_reqs,
                                          device=self.device,
                                          dtype=torch.int32)
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                draft_probs,
                target_logits,
                bonus_token_ids,
                dummy_metadata,
            )
        return sampler_output

    def _dummy_pooler_run_task(
        self,
        hidden_states: torch.Tensor,
        task: PoolingTask,
    ) -> PoolerOutput:
        """Run the model's pooler once for ``task`` on dummy requests.

        The ``hidden_states.shape[0]`` tokens are split across
        ``min(num_tokens, max_num_seqs)`` dummy requests, and the last request
        takes the remainder. Each request gets ``PoolingParams(task=task)``
        after applying the pooler's updates for that task.

        Raises:
            RuntimeError: with a hint to lower ``max_num_seqs`` /
                ``gpu_memory_utilization`` when the pooler runs out of memory;
                other runtime errors are re-raised unchanged.
        """
        num_tokens = hidden_states.shape[0]
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        # The last request absorbs the remainder and may be longer than the
        # rest. Size the token-id matrix by the longest prompt so that every
        # row covers its full prompt length in `prompt_lens`.
        max_prompt_len = max(num_scheduled_tokens_list)

        dummy_prompt_lens = torch.tensor(
            num_scheduled_tokens_list,
            device="cpu",
        )
        dummy_token_ids = torch.zeros((num_reqs, max_prompt_len),
                                      dtype=torch.int32,
                                      device=self.device)

        model = cast(VllmModelForPooling, self.get_model())
        dummy_pooling_params = PoolingParams(task=task)
        to_update = model.pooler.get_pooling_updates(task)
        to_update.apply(dummy_pooling_params)

        dummy_metadata = PoolingMetadata(
            prompt_lens=dummy_prompt_lens,
            prompt_token_ids=dummy_token_ids,
            pooling_params=[dummy_pooling_params] * num_reqs,
        )

        dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
                                            device=hidden_states.device)

        try:
            return model.pooler(hidden_states=hidden_states,
                                pooling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                raise RuntimeError(
                    "CUDA out of memory occurred when warming up pooler "
                    f"({task=}) with {num_reqs} dummy requests. Please try "
                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                    "initializing the engine.") from e
            raise

    @torch.inference_mode()
    def _dummy_pooler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> PoolerOutput:
        """Warm up the pooler for every supported pooling task.

        Each supported task is run once on the full dummy batch, so that an
        out-of-memory condition in any of them surfaces during profiling.
        The task whose output reports the most data bytes (the first one on
        ties) is then run again, and that output is returned.
        """
        nbytes_by_task: dict[PoolingTask, float] = {}
        for pooling_task in self.get_supported_pooling_tasks():
            # A full-batch run per task ensures none of them OOMs later.
            task_output = self._dummy_pooler_run_task(hidden_states,
                                                      pooling_task)
            nbytes_by_task[pooling_task] = task_output.get_data_nbytes()
            del task_output  # Drop the reference so it can be collected.

        largest_task = max(nbytes_by_task, key=nbytes_by_task.__getitem__)
        return self._dummy_pooler_run_task(hidden_states, largest_task)

    def profile_run(self) -> None:
        """Run a worst-case dummy pass so peak memory usage can be observed.

        The multimodal step runs when the model accepts multimodal inputs and
        multimodal profiling is not skipped. The encoder is run on a dummy
        batch of the modality with the most tokens, and its outputs are kept
        in `self.encoder_cache` during the rest of the pass.

        A dummy forward pass over `self.max_num_tokens` tokens follows. On
        the last pipeline-parallel rank it is followed by a dummy pooler run
        (pooling models) or a dummy sampler run.

        Finally, the device is synchronized, the dummy tensors and the
        encoder cache are released, and garbage is collected.
        """
        # Profile with multimodal encoder & encoder cache.
        if self.supports_mm_inputs:
            if self.model_config.multimodal_config.skip_mm_profiling:
                logger.info(
                    "Skipping memory profiling for multimodal encoder and "
                    "encoder cache.")
            else:
                mm_budget = self.mm_budget
                assert mm_budget is not None

                # TODO: handle encoder-decoder models once we support them.
                if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
                    # NOTE: Currently model is profiled with a single non-text
                    # modality with the max possible input tokens even when
                    # it supports multiple.
                    dummy_modality = mm_budget.get_modality_with_max_tokens()
                    max_mm_items_per_batch = mm_budget \
                        .max_items_per_batch_by_modality[dummy_modality]

                    logger.info(
                        "Encoder cache will be initialized with a budget of "
                        "%s tokens, and profiled with %s %s items of the "
                        "maximum feature size.",
                        encoder_budget,
                        max_mm_items_per_batch,
                        dummy_modality,
                    )

                    # Create dummy batch of multimodal inputs.
                    batched_dummy_mm_inputs = self._get_mm_dummy_batch(
                        dummy_modality,
                        max_mm_items_per_batch,
                    )

                    # Run multimodal encoder.
                    dummy_encoder_outputs = \
                        self.model.get_multimodal_embeddings(
                        **batched_dummy_mm_inputs)

                    sanity_check_mm_encoder_outputs(
                        dummy_encoder_outputs,
                        expected_num_items=max_mm_items_per_batch,
                    )

                    # Cache the dummy encoder outputs.
                    self.encoder_cache["tmp"] = dict(
                        enumerate(dummy_encoder_outputs))

        # Add `is_profile` here to pre-allocate communication buffers
        hidden_states, last_hidden_states \
            = self._dummy_run(self.max_num_tokens, is_profile=True)
        if get_pp_group().is_last_rank:
            if self.is_pooling_model:
                output = self._dummy_pooler_run(hidden_states)
            else:
                output = self._dummy_sampler_run(last_hidden_states)
        else:
            output = None
        self._sync_device()
        # Drop every reference to the dummy forward outputs, including
        # `last_hidden_states`, before collecting garbage.
        del hidden_states, last_hidden_states, output
        self.encoder_cache.clear()
        gc.collect()

    def capture_model(self) -> None:
        """Capture CUDA graphs according to `compilation_config.cudagraph_mode`.

        Returns immediately, with a warning, when the mode is `NONE`.
        Otherwise the mode is first resolved against the attention backends
        by `initialize_cudagraph_capture`, which may downgrade it. Then:

        * if the mixed mode is not `NONE`, graphs for mixed prefill-decode
          batches are captured for every size in `self.cudagraph_batch_sizes`;
        * if the decode mode is `FULL` and `separate_routine()` is true, FULL
          graphs for uniform decode batches are also captured.

        Sizes are captured largest first. Cudagraph capturing is globally
        disabled afterwards, even if capture raised. The elapsed time and the
        drop in free GPU memory are logged.
        """
        if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`")
            return

        self.initialize_cudagraph_capture()

        compilation_counter.num_gpu_runner_capture_triggers += 1

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        @contextmanager
        def freeze_gc():
            # Optimize garbage collection during CUDA graph capture.
            # Clean up, then freeze all remaining objects from being included
            # in future collections.
            gc.collect()
            should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
            if should_freeze:
                gc.freeze()
            try:
                yield
            finally:
                if should_freeze:
                    gc.unfreeze()

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        set_cudagraph_capturing_enabled(True)
        try:
            with freeze_gc(), graph_capture(device=self.device):
                cudagraph_mode = self.compilation_config.cudagraph_mode
                if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                    cudagraph_runtime_mode = cudagraph_mode.mixed_mode()

                    compilation_cases = list(
                        reversed(self.cudagraph_batch_sizes))
                    self._capture_cudagraphs(
                        compilation_cases,
                        cudagraph_runtime_mode=cudagraph_runtime_mode,
                        uniform_decode=False)

                # Capture full cudagraphs for uniform decode batches if we
                # don't already have full mixed prefill-decode cudagraphs.
                if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
                        and cudagraph_mode.separate_routine()):
                    max_num_tokens = (self.scheduler_config.max_num_seqs *
                                      self.uniform_decode_query_len)
                    decode_cudagraph_batch_sizes = [
                        x for x in self.cudagraph_batch_sizes
                        if self.uniform_decode_query_len <= x <= max_num_tokens
                    ]
                    compilation_cases_decode = list(
                        reversed(decode_cudagraph_batch_sizes))
                    self._capture_cudagraphs(
                        compilation_cases=compilation_cases_decode,
                        cudagraph_runtime_mode=CUDAGraphMode.FULL,
                        uniform_decode=True)
        finally:
            # Disable cudagraph capturing globally, even if capture failed,
            # so any unexpected cudagraph capturing will be detected and
            # raise an error after here.
            # Note: We don't put it into graph_capture context manager because
            # we may do lazy capturing in future that still allows capturing
            # after here.
            set_cudagraph_capturing_enabled(False)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def _capture_cudagraphs(self, compilation_cases: list[int],
                            cudagraph_runtime_mode: CUDAGraphMode,
                            uniform_decode: bool):
        """Warm up and then capture one CUDA graph per token count.

        For each entry of `compilation_cases`, `_dummy_run` is first called
        `cudagraph_num_of_warmups` times with `CUDAGraphMode.NONE`. It is then
        called once with `cudagraph_runtime_mode`. EPLB is skipped in all of
        these runs.

        Args:
            compilation_cases: Token counts to capture, in capture order.
            cudagraph_runtime_mode: Either `CUDAGraphMode.FULL` or
                `CUDAGraphMode.PIECEWISE`.
            uniform_decode: Capture uniform-decode batches instead of mixed
                prefill-decode batches.
        """
        assert cudagraph_runtime_mode in (CUDAGraphMode.FULL,
                                          CUDAGraphMode.PIECEWISE)

        # Warmup always uses `CUDAGraphMode.NONE`, which is orthogonal to
        # whether attention is warmed up. That is different from capture,
        # where `FULL` implies capturing attention while `PIECEWISE` implies
        # not. So force attention during warmup only for FULL graphs. This
        # does not depend on the loop variables and is computed once.
        force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL

        cases = compilation_cases
        # Only rank 0 should print progress bar during capture
        if is_global_first_rank():
            cases = tqdm(
                compilation_cases,
                disable=not self.load_config.use_tqdm_on_load,
                desc="Capturing CUDA graphs ({}, {})".format(
                    "decode" if uniform_decode else "mixed prefill-decode",
                    cudagraph_runtime_mode.name))
        # We skip EPLB here since we don't want to record dummy metrics
        for num_tokens in cases:
            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                self._dummy_run(num_tokens,
                                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                force_attention=force_attention,
                                uniform_decode=uniform_decode,
                                skip_eplb=True)
            self._dummy_run(num_tokens,
                            cudagraph_runtime_mode=cudagraph_runtime_mode,
                            uniform_decode=uniform_decode,
                            skip_eplb=True)

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.

        The layers of each KV cache group are bucketed by attention backend.
        One `AttentionGroup` (backend, metadata builder, layer names) is
        created per bucket and appended to `self.attn_groups` as one list per
        KV cache group. The batch reorder threshold is then derived from the
        builders.

        Args:
            kv_cache_config: The KV cache config whose groups are processed.

        Raises:
            AssertionError: If attention backends were already initialized.
        """
        assert len(self.attn_groups) == 0, \
            "Attention backends are already initialized"

        def get_attn_backends_for_layers(
                layer_names: list[str]
        ) -> dict[type[AttentionBackend], list[str]]:
            layers = get_layers_from_vllm_config(self.vllm_config,
                                                 AttentionLayerBase,
                                                 layer_names)
            attn_backends = {}
            attn_backend_layers = defaultdict(list)
            # Dedupe based on full class name; this is a bit safer than using
            # the class itself as the key because when we create dynamic
            # attention backend subclasses (e.g. ChunkedLocalAttention) unless
            # they are cached correctly, there will be different objects per
            # layer.
            for layer_name in layer_names:
                attn_backend = layers[layer_name].get_attn_backend()
                key = attn_backend.full_cls_name()
                attn_backends[key] = attn_backend
                attn_backend_layers[key].append(layer_name)
            return {
                attn_backends[k]: v
                for k, v in attn_backend_layers.items()
            }

        def create_attn_groups(
            attn_backends_map: dict[type[AttentionBackend], list[str]],
            kv_cache_spec: KVCacheSpec,
        ) -> list[AttentionGroup]:
            attn_groups: list[AttentionGroup] = []
            for attn_backend, layer_names in attn_backends_map.items():
                attn_metadata_builder_i = attn_backend.get_builder_cls()(
                    kv_cache_spec,
                    layer_names,
                    self.vllm_config,
                    self.device,
                )
                attn_group = AttentionGroup(attn_backend,
                                            attn_metadata_builder_i,
                                            layer_names)
                attn_groups.append(attn_group)
            return attn_groups

        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            attn_backends = get_attn_backends_for_layers(
                kv_cache_group_spec.layer_names)
            self.attn_groups.append(
                create_attn_groups(attn_backends, kv_cache_spec))

        # Calculate reorder batch threshold (if needed)
        self.calculate_reorder_batch_threshold()

    def initialize_cudagraph_capture(self) -> None:
        """Resolve the cudagraph mode against the attention backends.

        The weakest `cudagraph_support` among all attention metadata
        builders is found. If it cannot handle the requested mode, the mode
        is downgraded, the new value is written back to
        `self.compilation_config.cudagraph_mode`, and a warning is logged.
        The cudagraph dispatcher keys are then initialized with the resolved
        mode.

        Raises:
            ValueError: If full cudagraphs are requested but some backend
                never supports them.
        """
        min_cg_support = AttentionCGSupport.ALWAYS
        min_cg_builder_name = None

        for attn_group in self._attn_group_iterator():
            builder = attn_group.metadata_builder
            if builder.cudagraph_support.value < min_cg_support.value:
                min_cg_support = builder.cudagraph_support
                min_cg_builder_name = builder.__class__.__name__

        # Resolve the cudagraph mode flexibly, downgrading it when the
        # attention backends cannot support it.
        cudagraph_mode = self.compilation_config.cudagraph_mode
        # check cudagraph for mixed batch is supported
        if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \
            and min_cg_support != AttentionCGSupport.ALWAYS:
            msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
                   f"with {min_cg_builder_name} backend (support: "
                   f"{min_cg_support})")
            if min_cg_support == AttentionCGSupport.NEVER:
                # if not supported any full cudagraphs, just raise it.
                msg += "; please try cudagraph_mode=PIECEWISE, and "\
                    "make sure compilation level is piecewise"
                raise ValueError(msg)

            # attempt to resolve the full cudagraph related mode
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.FULL_AND_PIECEWISE
            else:
                msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.FULL_DECODE_ONLY
            logger.warning(msg)

        # check that if we are doing spec-decode + decode full-cudagraphs it is
        # supported
        if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
                and self.uniform_decode_query_len > 1 and min_cg_support.value
                < AttentionCGSupport.UNIFORM_BATCH.value):
            msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
                   f" with spec-decode for attention backend "
                   f"{min_cg_builder_name} (support: {min_cg_support})")
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=PIECEWISE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.PIECEWISE
            else:
                msg += "; setting cudagraph_mode=NONE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.NONE
            logger.warning(msg)

        # double check that we can support full cudagraph if they are requested
        # even after automatic downgrades
        if cudagraph_mode.has_full_cudagraphs() \
            and min_cg_support == AttentionCGSupport.NEVER:
            raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not "
                             f"supported with {min_cg_builder_name} backend ("
                             f"support: {min_cg_support}); please try "
                             "cudagraph_mode=PIECEWISE, and make sure "
                             "compilation level is piecewise")

        # Trigger cudagraph dispatching keys initialization here (after
        # initializing attn backends).
        self.cudagraph_dispatcher.initialize_cudagraph_keys(
            self.compilation_config.cudagraph_mode,
            self.uniform_decode_query_len)

    def calculate_reorder_batch_threshold(self) -> None:
        """
        Derive `self.reorder_batch_threshold` from the attention metadata
        builders, making sure all reordering backends agree on it.

        Builders whose `reorder_batch_threshold` is None are ignored. If
        `self.reorder_batch_threshold` is None, it takes the first non-None
        builder value. Every other non-None builder value must equal it.

        Raises:
            ValueError: If two backends request different thresholds.
        """
        for attn_group in self._attn_group_iterator():
            threshold = attn_group.metadata_builder.reorder_batch_threshold
            if threshold is None:
                # This backend does not reorder batches.
                continue
            if self.reorder_batch_threshold is None:
                self.reorder_batch_threshold = threshold
            elif threshold != self.reorder_batch_threshold:
                raise ValueError(
                    f"Attention backend reorders decodes with "
                    f"threshold {threshold} but other "
                    f"backend uses threshold "
                    f"{self.reorder_batch_threshold}")

    def may_reinitialize_input_batch(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Rebuild `self.input_batch` when the KV cache groups' block sizes are
        anything other than exactly `[self.cache_config.block_size]`. This
        typically happens when there are several KV cache groups.

        The rebuilt `InputBatch` reuses the current batch's `logitsprocs`.

        Args:
            kv_cache_config: The KV cache configuration.

        Raises:
            AssertionError: If a rebuild is needed while CPU weight
                offloading (`cpu_offload_gb`) is enabled.
        """
        block_sizes = [
            group.kv_cache_spec.block_size
            for group in kv_cache_config.kv_cache_groups
        ]
        if block_sizes == [self.cache_config.block_size]:
            # Single group with the configured block size: nothing to do.
            return

        assert self.cache_config.cpu_offload_gb == 0, (
            "Cannot re-initialize the input batch when CPU weight "
            "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
            "for more details.")
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=block_sizes,
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=self.input_batch.logitsprocs,
            is_pooling_model=self.is_pooling_model,
        )

    def _allocate_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Initializes the KV cache buffer with the correct size. The buffer needs
        to be reshaped to the desired shape before being used by the models.

        One zero-filled int8 tensor of `kv_cache_tensor.size` elements is
        allocated on `self.device` for each entry of
        `kv_cache_config.kv_cache_tensors`. Every layer listed in that
        entry's `shared_by` maps to the same tensor.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        Raises:
            AssertionError: If the layers that received a buffer differ from
                the layers of the KV cache groups (runner-only attention
                layers excluded). The message lists the mismatching names.
        """
        kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            tensor = torch.zeros(kv_cache_tensor.size,
                                 dtype=torch.int8,
                                 device=self.device)
            for layer_name in kv_cache_tensor.shared_by:
                kv_cache_raw_tensors[layer_name] = tensor

        # Runner-only attention layers are excluded from the check.
        layer_names = {
            layer_name
            for group in kv_cache_config.kv_cache_groups
            for layer_name in group.layer_names
            if layer_name not in self.runner_only_attn_layers
        }
        allocated_names = set(kv_cache_raw_tensors)
        assert layer_names == allocated_names, (
            "Some layers are not correctly initialized: "
            f"missing={sorted(layer_names - allocated_names)}, "
            f"unexpected={sorted(allocated_names - layer_names)}")
        return kv_cache_raw_tensors

    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
        """Iterate over every `AttentionGroup` of every KV cache group.

        `self.attn_groups` holds one list of groups per KV cache group; this
        flattens those lists lazily into a single stream.
        """
        return (attn_group
                for groups_of_kv_cache_group in self.attn_groups
                for attn_group in groups_of_kv_cache_group)

    def _kv_cache_spec_attn_group_iterator(
            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
        """Yield `(kv_cache_spec, attn_group)` for every attention group.

        `self.attn_groups[i]` holds the groups of KV cache group `i`, so each
        of them is paired with the spec of
        `self.kv_cache_config.kv_cache_groups[i]`. Nothing is yielded when
        the KV cache config has no groups.
        """
        if self.kv_cache_config.kv_cache_groups:
            for group_idx, groups_for_spec in enumerate(self.attn_groups):
                for attn_group in groups_for_spec:
                    spec_group = self.kv_cache_config.kv_cache_groups[
                        group_idx]
                    yield spec_group.kv_cache_spec, attn_group

    def _reshape_kv_cache_tensors(
        self,
        kv_cache_config: KVCacheConfig,
        kv_cache_raw_tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Reshape the KV cache tensors to the desired shape and dtype.

        Each attention layer gets a single view of its raw buffer. The view
        is laid out in the backend's stride order, but its dimensions are
        permuted back to the `get_kv_cache_shape` order. Each Mamba layer
        gets a list of strided views, one per `(shape, dtype)` state in its
        spec, placed back to back inside every page. When both kinds are
        present, the layout is verified with
        `_verify_hybrid_attention_mamba_layout`. Runner-only attention
        layers are skipped.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer, with
            correct size but uninitialized shape.
        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache. For Mamba layers the
            value is a list of state tensors.
        Raises:
            NotImplementedError: For a KV cache spec that is neither an
                `AttentionSpec` nor a `MambaSpec`.
        """
        kv_caches: dict[str, torch.Tensor] = {}
        has_attn, has_mamba = False, False
        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
            attn_backend = group.backend
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                raw_tensor = kv_cache_raw_tensors[layer_name]
                assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                num_blocks = (raw_tensor.numel() //
                              kv_cache_spec.page_size_bytes)
                if isinstance(kv_cache_spec, AttentionSpec):
                    has_attn = True
                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    try:
                        kv_cache_stride_order = \
                            attn_backend.get_kv_cache_stride_order()
                        assert len(kv_cache_stride_order) == len(
                            kv_cache_shape)
                    except (AttributeError, NotImplementedError):
                        kv_cache_stride_order = tuple(
                            range(len(kv_cache_shape)))
                    # The allocation respects the backend-defined stride order
                    # to ensure the semantic remains consistent for each
                    # backend. We first obtain the generic kv cache shape and
                    # then permute it according to the stride order which could
                    # result in a non-contiguous tensor.
                    kv_cache_shape = tuple(kv_cache_shape[i]
                                           for i in kv_cache_stride_order)
                    # Maintain original KV shape view.
                    inv_order = [
                        kv_cache_stride_order.index(i)
                        for i in range(len(kv_cache_stride_order))
                    ]
                    kv_caches[layer_name] = raw_tensor.view(dtype).view(
                        kv_cache_shape).permute(*inv_order)
                elif isinstance(kv_cache_spec, MambaSpec):
                    has_mamba = True
                    state_tensors = []
                    storage_offset_bytes = 0
                    for (shape, dtype) in zip(kv_cache_spec.shapes,
                                              kv_cache_spec.dtypes):
                        dtype_size = get_dtype_size(dtype)
                        num_element_per_page = (
                            kv_cache_spec.page_size_bytes // dtype_size)
                        target_shape = (num_blocks, *shape)
                        # Only the contiguous strides of `target_shape` are
                        # needed. Computing them on the meta device avoids
                        # allocating a real (num_blocks, *shape) CPU tensor.
                        stride = torch.empty(target_shape,
                                             device="meta").stride()
                        target_stride = (num_element_per_page, *stride[1:])
                        assert storage_offset_bytes % dtype_size == 0
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset_bytes // dtype_size,
                        )
                        state_tensors.append(tensor)
                        # The next state starts right after this one's
                        # per-block elements.
                        storage_offset_bytes += stride[0] * dtype_size

                    kv_caches[layer_name] = state_tensors
                else:
                    raise NotImplementedError(
                        "Unsupported KV cache spec type: "
                        f"{type(kv_cache_spec).__name__}")

        if has_attn and has_mamba:
            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
                                                       kv_cache_raw_tensors)

        return kv_caches

    def _verify_hybrid_attention_mamba_layout(
            self, kv_cache_config: KVCacheConfig,
            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
        """
        Verify that the KV cache memory layout is compatible for
        models with both attention and mamba KV cache groups.

        Every attention layer's backend must report a KV cache shape of the
        form `(num_blocks, 2, ...)`. Runner-only attention layers are
        skipped, because the allocation and reshape paths do not give them
        a raw buffer.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer.
        Raises:
            ValueError: If an attention backend uses an incompatible layout.
        """
        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
            if not isinstance(kv_cache_spec, AttentionSpec):
                # Only attention layouts are constrained here.
                continue
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                raw_tensor = kv_cache_raw_tensors[layer_name]
                num_blocks = (raw_tensor.numel() //
                              kv_cache_spec.page_size_bytes)
                kv_cache_shape = group.backend.get_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                if kv_cache_shape[0] != num_blocks or kv_cache_shape[1] != 2:
                    raise ValueError(
                        "Hybrid models in V1 require an attention "
                        "backend with kv_cache_shape="
                        "(num_blocks, 2, ...). Please try setting "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")

    def initialize_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Allocate, reshape and bind the KV cache memory buffers.

        Raw buffers are allocated and reshaped per layer. For models with
        cross-layer KV sharing, the caches are then set up for the shared
        layers. Finally they are bound to the static forward context and
        `self.kv_caches`.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        raw_buffers = self._allocate_kv_cache_tensors(kv_cache_config)
        kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                   raw_buffers)

        if self.shared_kv_cache_layers:
            # Setup `kv_cache_config` and `kv_caches` for models
            # with cross-layer KV sharing.
            initialize_kv_cache_for_kv_sharing(
                self.shared_kv_cache_layers,
                kv_cache_config.kv_cache_groups,
                kv_caches,
                self.attn_groups,
                self.runner_only_attn_layers,
            )
            attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                      Attention)
            # Walk the attention layers from last to first and collect the
            # trailing run that re-uses KV cache, e.g. in YOCO-like KV
            # sharing setups (e.g. Gemma3n).
            trailing_shared_layers = itertools.takewhile(
                lambda name: name in self.shared_kv_cache_layers,
                reversed(attn_layers))
            for shared_name in trailing_shared_layers:
                self.kv_sharing_fast_prefill_eligible_layers.add(shared_name)

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)
        return kv_caches

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Set up this runner's KV cache from `kv_cache_config`.

        The config is deep-copied before use. That copy is stored on
        `self.kv_cache_config` and may be extended afterwards (encoder-only
        groups are appended to it), so the caller's object is never modified.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer
        """
        # Private copy; later steps mutate it through `self.kv_cache_config`.
        config = deepcopy(kv_cache_config)
        self.kv_cache_config = config

        self.may_reinitialize_input_batch(config)
        self.may_add_encoder_only_layers_to_kv_cache_config()
        self.initialize_attn_backend(config)
        kv_caches = self.initialize_kv_cache_tensors(config)

        spec_cfg = self.speculative_config
        if spec_cfg and spec_cfg.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # Every draft model layer has to live in one KV cache group.
            self.drafter.validate_same_kv_cache_group(config)

        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)

3154
3155
3156
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
        """
        Add encoder-only attention layers to `self.kv_cache_config`.

        Each attention layer whose `attn_type` is `AttentionType.ENCODER_ONLY`
        gets an `EncoderOnlyAttentionSpec` and is recorded in
        `self.runner_only_attn_layers`. All such layers are then appended to
        `self.kv_cache_config.kv_cache_groups` as a single new
        `KVCacheGroupSpec`. If there are no encoder-only layers, nothing is
        changed.

        Raises:
            NotImplementedError: if the encoder-only layers do not all map to
                the same attention spec (only one encoder-only group is
                supported).
        """
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        # Group encoder-only layer names by their (hashable) spec.
        encoder_only_attn_specs: dict[AttentionSpec,
                                      list[str]] = defaultdict(list)
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if attn_module.attn_type != AttentionType.ENCODER_ONLY:
                continue
            attn_spec = EncoderOnlyAttentionSpec(
                block_size=block_size,
                num_kv_heads=attn_module.num_kv_heads,
                head_size=attn_module.head_size,
                dtype=self.kv_cache_dtype,
                use_mla=use_mla)
            encoder_only_attn_specs[attn_spec].append(layer_name)
            self.runner_only_attn_layers.add(layer_name)

        if not encoder_only_attn_specs:
            return
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and then popitem() below would silently keep only one
        # group, leaving the remaining encoder-only layers without a KV cache
        # group.
        if len(encoder_only_attn_specs) != 1:
            raise NotImplementedError(
                "Only support one encoder-only attention spec now")
        spec, layer_names = encoder_only_attn_specs.popitem()
        self.kv_cache_config.kv_cache_groups.append(
            KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))

3181
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generate a KV cache spec for every attention and Mamba layer returned
        by `get_layers_from_vllm_config` that needs its own KV cache.

        Side effect: attention layers that reuse another layer's KV cache
        (`kv_sharing_target_layer_name` is set) are recorded in
        `self.shared_kv_cache_layers` and receive no spec of their own.

        Returns:
            dict[str, KVCacheSpec]: A dictionary mapping layer names to their
            KV cache format. Layers that do not need KV cache are not
            included.

        Raises:
            NotImplementedError: for encoder-decoder (cross) attention layers,
                and for Mamba layers combined with speculative decoding or
                prefix caching.
            ValueError: for an unrecognized attention type.
        """
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}

        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if (kv_tgt_layer :=
                    attn_module.kv_sharing_target_layer_name) is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as this layer does
                # not exist, and doesn't allocate KV cache for the layer. This
                # enables the memory saving of cross-layer kv sharing, allowing
                # a given amount of memory to accommodate longer context lengths
                # or enable more requests to be processed simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends
            if attn_module.attn_type == AttentionType.DECODER:
                if attn_module.sliding_window is not None:
                    kv_cache_spec[layer_name] = SlidingWindowSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        sliding_window=attn_module.sliding_window,
                        use_mla=use_mla)
                elif (self.attention_chunk_size is not None
                      and isinstance(attn_module, ChunkedLocalAttention)):
                    kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        attention_chunk_size=self.attention_chunk_size,
                        use_mla=use_mla)
                else:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        use_mla=use_mla)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # No spec is produced here for encoder / encoder-only layers.
                # Encoder-only layers are added to the KV cache config
                # separately by may_add_encoder_only_layers_to_kv_cache_config().
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError(
                    "Encoder-decoder (cross) attention is not supported yet "
                    f"(layer: {layer_name}).")
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
        if mamba_layers:
            if self.vllm_config.speculative_config is not None:
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = (
                self.vllm_config.cache_config.mamba_page_size_padded)

            # Set block_size to max_model_len, so that mamba model will always
            # have only one block in the KV cache.
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtypes=mamba_module.get_state_dtype(),
                    block_size=max_model_len,
                    page_size_padded=page_size_padded,
                    mamba_type=mamba_module.mamba_type)

        return kv_cache_spec