gpu_model_runner.py 125 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Robert Shaw's avatar
Robert Shaw committed
4
import copy
5
import gc
6
import time
7
import weakref
8
from contextlib import contextmanager
9
from typing import TYPE_CHECKING, Any, Optional, Union
10
11
12
13
14

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
15
from tqdm import tqdm
16

17
import vllm.envs as envs
18
from vllm.attention import AttentionType, get_attn_backend
19
from vllm.attention.backends.abstract import AttentionBackend
20
from vllm.attention.layer import Attention
21
from vllm.compilation.counter import compilation_counter
22
23
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
24
from vllm.distributed.eplb.eplb_state import EplbState
25
26
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
Robert Shaw's avatar
Robert Shaw committed
27
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
28
from vllm.distributed.parallel_state import (
29
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
30
    prepare_communication_buffer_for_model)
31
32
from vllm.forward_context import (DPMetadata, get_forward_context,
                                  set_forward_context)
33
from vllm.logger import init_logger
Chen Zhang's avatar
Chen Zhang committed
34
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
35
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
36
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
37
38
from vllm.model_executor.models.interfaces import (has_step_pooler,
                                                   is_mixture_of_experts)
39
40
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
41
from vllm.multimodal.utils import group_mm_inputs_by_modality
42
from vllm.pooling_params import PoolingParams
43
from vllm.sampling_params import SamplingType
44
from vllm.sequence import IntermediateTensors
45
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
46
                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
Chen Zhang's avatar
Chen Zhang committed
47
                        check_use_alibi, get_dtype_size,
Woosuk Kwon's avatar
Woosuk Kwon committed
48
                        is_pin_memory_available, round_up)
Chen Zhang's avatar
Chen Zhang committed
49
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
50
51
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
52
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
53
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
Chen Zhang's avatar
Chen Zhang committed
54
                                        KVCacheConfig, KVCacheSpec, MambaSpec,
55
                                        SlidingWindowSpec)
56
57
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
58
from vllm.v1.pool.metadata import PoolingMetadata
59
from vllm.v1.sample.metadata import SamplingMetadata
60
from vllm.v1.sample.rejection_sampler import RejectionSampler
61
from vllm.v1.sample.rejection_sampler_mtp import MtpRejectionSampler
62
from vllm.v1.sample.sampler import Sampler
63
from vllm.v1.spec_decode.eagle import EagleProposer
64
from vllm.v1.spec_decode.medusa import MedusaProposer
65
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
66
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
67
from vllm.v1.spec_decode.utils import DraftProbs
68
from vllm.v1.utils import bind_kv_cache
69
from vllm.v1.worker.block_table import BlockTable
70
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
71
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
zhuwenwen's avatar
zhuwenwen committed
72
from vllm.platforms import current_platform
73

74
from ..sample.logits_processor import LogitsProcessorManager
75
76
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
77

78
if TYPE_CHECKING:
79
    import xgrammar as xgr
80
    import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
81

82
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
83
    from vllm.v1.core.sched.output import SchedulerOutput
84
85
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
86
87
88
    xgr_torch_compile = LazyLoader(
        "xgr_torch_compile", globals(),
        "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")
89
90
91
92

logger = init_logger(__name__)


93
class GPUModelRunner(LoRAModelRunnerMixin):
94
95
96

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache the config sections and allocate the persistent buffers.

        Besides storing the individual sections of ``vllm_config`` on the
        instance, this sets up the sampler, the optional speculative-decoding
        drafter and rejection sampler, the persistent ``InputBatch`` and the
        pre-allocated device / host tensors used when preparing model inputs.

        Args:
            vllm_config: Aggregated engine configuration whose sub-configs
                (model, cache, scheduler, parallel, ...) are read here.
            device: Device on which the persistent input tensors are
                allocated.

        Raises:
            ValueError: If speculative decoding is configured on the last PP
                rank with a method that is neither ``"ngram"``, an EAGLE-style
                method (per ``speculative_config.use_eagle()``) nor
                ``"medusa"``.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
        # cpu_offload_gb is expressed in GiB; convert it to bytes.
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the same dtype as the model.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.is_pooling_model = model_config.pooler_config is not None
        self.max_model_len = model_config.max_model_len
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Sampler
        self.sampler = Sampler()

        # State of the expert parallelism load balancer.
        # Will be lazily initialized when the model is loaded.
        self.eplb_state: Optional[EplbState] = None

        # Lazy initializations
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
        self.attn_backends: list[type[AttentionBackend]] = []
        # self.kv_cache_config: KVCacheConfig

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        self.use_aux_hidden_state_outputs = False
        # `use_mtp` is read unconditionally by `_update_states`, so it must
        # exist even when speculative decoding is disabled or this worker is
        # not the last PP rank.
        self.use_mtp = False
        # Set up speculative decoding.
        # NOTE(Jiayi): currently we put the entire draft model on
        # the last PP rank. This is not ideal if there are many
        # layers in the draft model.
        if self.speculative_config and get_pp_group().is_last_rank:
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
            elif self.speculative_config.use_eagle():
                self.drafter = EagleProposer(self.vllm_config, self.device,
                                             self)  # type: ignore
                if self.speculative_config.method == "eagle3":
                    self.use_aux_hidden_state_outputs = True
            elif self.speculative_config.method == "medusa":
                self.drafter = MedusaProposer(
                    vllm_config=self.vllm_config,
                    device=self.device)  # type: ignore
            else:
                raise ValueError("Unknown speculative decoding method: "
                                 f"{self.speculative_config.method}")

            # The MTP method uses its own rejection sampler.
            self.use_mtp = self.speculative_config.method == "deepseek_mtp"
            if self.use_mtp:
                self.rejection_sampler = MtpRejectionSampler()
            else:
                self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
        )

        self.use_cuda_graph = (
            self.vllm_config.compilation_config.level
            == CompilationLevel.PIECEWISE
            and self.vllm_config.compilation_config.use_cudagraph
            and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(self.compilation_config.cudagraph_capture_sizes))

        self.full_cuda_graph = self.compilation_config.full_cuda_graph

        # Cache the device properties.
        self._init_device_properties()

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device=self.device)
        self.seq_lens = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device=self.device)
        self.slot_mapping = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)
            self.mrope_positions_np = self.mrope_positions_cpu.numpy()

        # Only relevant for models using ALiBi (e.g, MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int64)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}

        # Draft probabilities; pruned in `_update_states` when requests
        # finish and MTP is in use.
        self.draft_probs: Optional[DraftProbs] = None

330
    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
331
332
        """
        Update the order of requests in the batch based on the attention
333
        backend's needs. For example, some attention backends (namely MLA) may
334
335
336
337
338
339
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.
        """
340
341
        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                     scheduler_output)
342
343
344
345
346

        # For models with multiple KV cache groups, the groups should agree on
        # the same order of requests. We ensure this by only allowing the first
        # group to reorder the batch and asserting that all other groups do not
        # reorder the batch.
347
348
349
        # TODO(tdoublep): make this more flexible so that any group can
        # re-order the batch (not only the first).
        # TODO(tdoublep): verify this during engine init instead of at runtime
350
        for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
351
            batch_reordered = self.attn_metadata_builders[i].reorder_batch(
352
                self.input_batch, scheduler_output)
353
            assert not batch_reordered
354

355
356
357
358
359
360
361
362
363
364
365
    # Note: used for model runner override.
    def _init_device_properties(self) -> None:
        """Cache the CUDA device properties of ``self.device``.

        Stores the object returned by ``torch.cuda.get_device_properties`` as
        ``self.device_properties`` and its ``multi_processor_count`` as
        ``self.num_sms``.
        """
        props = torch.cuda.get_device_properties(self.device)
        self.device_properties = props
        self.num_sms = props.multi_processor_count

    # Note: used for model runner override.
    def _sync_device(self) -> None:
        """Synchronize via ``torch.cuda.synchronize()``.

        Kept as a separate method so that a model runner override can
        replace the synchronization call.
        """
        torch.cuda.synchronize()

366
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        Finished requests are dropped from the caches and the batch,
        unscheduled requests are removed from the batch only, new requests are
        cached, running/resumed requests are updated in place, and finally the
        batch is condensed, optionally reordered, and its metadata refreshed.

        Args:
            scheduler_output: The scheduler output for the current step.
        """
        finished_req_ids = scheduler_output.finished_req_ids

        # Remove finished requests from the cached states.
        for req_id in finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in finished_req_ids:
            self.input_batch.remove_request(req_id)

        # Prune draft probs of finished requests.
        # `use_mtp` is only assigned by `__init__` when speculative decoding
        # is configured on the last PP rank, so fall back to False instead of
        # raising AttributeError on every other worker.
        if (getattr(self, "use_mtp", False) and self.draft_probs is not None
                and finished_req_ids):
            self.draft_probs.prune(list(finished_req_ids))

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params
            # Seeded random sampling gets its own per-request generator.
            if (sampling_params and sampling_params.sampling_type
                    == SamplingType.RANDOM_SEED):
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            new_req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = new_req_state

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                audio_feature_lengths = []
                use_audio_in_video = False
                # Collect the M-RoPE related fields across all multi-modal
                # inputs of the request.
                for mm_input in new_req_state.mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])
                    if mm_input.get("audio_feature_lengths") is not None:
                        audio_feature_lengths.extend(
                            mm_input["audio_feature_lengths"])
                    if mm_input.get("use_audio_in_video") is True:
                        use_audio_in_video = True

                hf_config = self.model_config.hf_config

                (new_req_state.mrope_positions,
                 new_req_state.mrope_position_delta) = (
                     MRotaryEmbedding.get_input_positions_tensor(
                         new_req_state.prompt_token_ids,
                         hf_config=hf_config,
                         image_grid_thw=image_grid_thw,
                         video_grid_thw=video_grid_thw,
                         second_per_grid_ts=second_per_grid_ts,
                         audio_feature_lengths=audio_feature_lengths,
                         use_audio_in_video=use_audio_in_video,
                     ))

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_data.resumed_from_preemption[i]

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                                  req_state.num_tokens)
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(
                        new_token_ids[-num_new_tokens:])

            # Update the block IDs.
            if not resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                for block_ids, new_ids in zip(req_state.block_ids,
                                              new_block_ids):
                    block_ids.extend(new_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                # `new_token_ids` was bound in the `not is_last_rank` branch
                # above for this same request.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index,
                    start_token_index:end_token_index] = new_token_ids
                self.input_batch.num_tokens_no_spec[
                    req_index] = end_token_index
                self.input_batch.num_tokens[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = (
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
            if spec_token_ids:
                num_spec_tokens = len(spec_token_ids)
                start_index = self.input_batch.num_tokens_no_spec[req_index]
                end_token_index = start_index + num_spec_tokens
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
                # NOTE(woosuk): `num_tokens` here may include spec tokens.
                self.input_batch.num_tokens[req_index] += num_spec_tokens

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            self.input_batch.add_request(req_state)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow the attention backend to reorder the batch if it wants to.
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()
572

573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
        cumsum_dtype: Optional[np.dtype] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the cumulative sum and batched arange of the given array.
        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_tokens])
        """
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
        total_num_tokens = cu_num_tokens[-1]
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_tokens] - cumsums_offsets

        return cu_num_tokens, arange

593
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str, Any], bool, torch.Tensor,
               Optional[SpecDecodeMetadata], np.ndarray]:
        """Prepare the model inputs and attention metadata for one step.

        Fills the persistent CPU buffers (input ids, positions, per-group
        slot mappings, query start locations, sequence lengths) for the
        scheduled tokens, copies them to the device buffers, and builds one
        attention metadata object per KV cache group.

        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            attention_cuda_graphs: whether attention can run in cudagraph,
            logits_indices: indices of the tokens whose logits are used,
            spec_decode_metadata: None when no draft tokens are scheduled,
            num_scheduled_tokens: per-request scheduled token counts
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping for each KV cache group.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            block_size = kv_cache_group_spec.kv_cache_spec.block_size
            block_table: BlockTable = self.input_batch.block_table[
                kv_cache_group_id]
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
            # where K is the max_num_blocks_per_req and the block size is 2.
            # NOTE(woosuk): We can't simply use `token_indices // block_size`
            # here because M (max_model_len) is not necessarily divisible by
            # block_size.
            block_table_indices = (
                req_indices * block_table.max_num_blocks_per_req +
                positions_np // block_size)
            block_table_cpu = block_table.get_cpu_tensor()
            block_numbers = block_table_cpu.flatten(
            )[block_table_indices].numpy()
            block_offsets = positions_np % block_size
            np.add(
                block_numbers * block_size,
                block_offsets,
                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        # NOTE(review): the `*_np` buffers written here are presumably numpy
        # views of the `*_cpu` tensors copied below - confirm in __init__.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                       non_blocking=True)

        # Zero out the seq_lens entries past the active requests.
        self.seq_lens[num_reqs:].fill_(0)
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
        self.query_start_loc[num_reqs + 1:].fill_(
            self.query_start_loc_cpu[num_reqs].item())

        query_start_loc = self.query_start_loc[:num_reqs + 1]
        seq_lens = self.seq_lens[:num_reqs]

        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
        )

        attn_metadata: dict[str, Any] = {}
        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            builder = self.attn_metadata_builders[kv_cache_group_id]
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    builder,
                )

            attn_metadata_i = (builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata,
            ))

            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

        # Every builder must agree before attention may run in a cudagraph.
        attention_cuda_graphs = all(
            b.can_run_in_cudagraph(common_attn_metadata)
            for b in self.attn_metadata_builders)

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, attention_cuda_graphs, logits_indices,
                spec_decode_metadata, num_scheduled_tokens)
785

786
787
788
789
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
        kv_cache_spec: KVCacheSpec,
        attn_metadata_builder: AttentionMetadataBuilder,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.
            kv_cache_spec: KV cache spec of the group being processed; its
                block size, number of KV heads and sliding-window setting are
                read here.
            attn_metadata_builder: Builder whose `use_cascade_attention` is
                asked whether cascade attention should be used.

        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
                             kv_cache_spec.block_size)
        use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                              (isinstance(kv_cache_spec, FullAttentionSpec)
                               and kv_cache_spec.sliding_window is not None))
        assert isinstance(kv_cache_spec, AttentionSpec)
        # The builder makes the final call on whether cascade attention is
        # used for this batch.
        use_cascade = attn_metadata_builder.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=kv_cache_spec.num_kv_heads,
            use_alibi=self.use_alibi,
            use_sliding_window=use_sliding_window,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

876
877
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
878
        for index, req_id in enumerate(self.input_batch.req_ids):
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

916
917
918
919
920
921
922
                MRotaryEmbedding.get_next_input_positions_tensor(
                    out=self.mrope_positions_np,
                    out_offset=dst_start,
                    mrope_position_delta=req.mrope_position_delta,
                    context_len=num_computed_tokens + prompt_part_len,
                    num_new_tokens=completion_part_len,
                )
923
924
925

                mrope_pos_ptr += completion_part_len

926
927
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Build the index tensors used for speculative decoding.

        Args:
            num_draft_tokens: Number of draft tokens per request (0 for
                requests without drafts).
            cu_num_scheduled_tokens: Inclusive cumulative sum of the
                scheduled token counts per request.

        Returns:
            A `SpecDecodeMetadata` holding the draft token ids together with
            the logits / target / bonus index tensors on `self.device`.
        """
        # Inputs:
        # cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
        # num_draft_tokens:         [  3,   0,   2,   0,   1]
        # Outputs:
        # cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
        # logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
        #                            206, 207, 208]
        # target_logits_indices:    [  0,   1,   2,   5,   6,   9]
        # bonus_logits_indices:     [  3,   4,   7,   8,  10]

        # Compute the logits indices.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1

        # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
        # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
            num_sampled_tokens, cumsum_dtype=np.int32)
        # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_indices = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices += arange

        # Compute the bonus logits indices.
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # Compute the target logits indices.
        # cu_num_draft_tokens: [3, 3, 5, 5, 6]
        # arange: [0, 1, 2, 0, 1, 0]
        cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
            num_draft_tokens, cumsum_dtype=np.int32)
        # [0, 0, 0, 5, 5, 9]
        target_logits_indices = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices += arange

        # TODO: Optimize the CPU -> GPU copy.
        cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
            self.device, non_blocking=True)
        logits_indices = torch.from_numpy(logits_indices).to(self.device,
                                                             non_blocking=True)
        target_logits_indices = torch.from_numpy(target_logits_indices).to(
            self.device, non_blocking=True)
        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
            self.device, non_blocking=True)

        # Compute the draft token ids.
        # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
        # The draft token for a target logit sits one slot after it within
        # the gathered `logits_indices` sequence, hence the `+ 1`.
        draft_token_ids = self.input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]

        metadata = SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata

994
    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
995
996
997
998
999
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
1000
1001
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
1002
1003
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
1004
1005
1006
1007
1008

            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
1021
1022
            batched_mm_inputs = MultiModalKwargs.batch(
                grouped_mm_inputs, pin_memory=self.pin_memory)
1023
1024
1025
1026
            batched_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_mm_inputs,
                device=self.device,
            )
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

1038
1039
1040
1041
1042
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

1043
1044
            for output in curr_group_outputs:
                encoder_outputs.append(output)
1045
1046

        # Cache the encoder outputs.
1047
1048
1049
1050
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
                encoder_outputs,
        ):
1051
1052
1053
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}

1054
1055
1056
1057
1058
1059
            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )

    def _gather_mm_embeddings(
1060
1061
        self,
        scheduler_output: "SchedulerOutput",
1062
    ) -> list[torch.Tensor]:
1063
        mm_embeds: list[torch.Tensor] = []
1064
        for req_id in self.input_batch.req_ids:
1065
1066
1067
1068
1069
1070
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
1071
1072
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
        return mm_embeds
1104

1105
1106
1107
    def get_model(self) -> nn.Module:
        """Return the model held in `self.model`."""
        model = self.model
        return model

1108
1109
1110
1111
1112
1113
1114
1115
1116
    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        """Apply the structured-output grammar bitmask to `logits` in place.

        Does nothing when the scheduler sent no bitmask. Otherwise the
        compact bitmask is reordered to match the row order of `logits`
        before being applied.

        Args:
            scheduler_output: Provides `grammar_bitmask`,
                `structured_output_request_ids` and
                `scheduled_spec_decode_tokens`.
            logits: Logits to mask; modified in place.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        # We receive the structured output bitmask from the scheduler,
        # compacted to contain bitmasks only for structured output requests.
        # The order of the requests in the bitmask is not guaranteed to be the
        # same as the order of the requests in the gpu runner's batch. We need
        # to sort the bitmask to match the order of the requests used here.

        # Get the batch indices of the structured output requests.
        # Keep track of the number of speculative tokens scheduled for every
        # request in the batch, as the logit indices are offset by this amount.
        struct_out_req_batch_indices: dict[str, int] = {}
        cumulative_offset = 0
        seq = sorted(self.input_batch.req_id_to_index.items(),
                     key=lambda x: x[1])
        for req_id, batch_index in seq:
            logit_index = batch_index + cumulative_offset
            cumulative_offset += len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            if req_id in scheduler_output.structured_output_request_ids:
                struct_out_req_batch_indices[req_id] = logit_index

        # Rows of `logits` that receive a bitmask.
        out_indices = []

        # Reorder the bitmask to match the order of the requests in the batch.
        sorted_bitmask = np.zeros_like(grammar_bitmask,
                                       shape=(logits.shape[0],
                                              grammar_bitmask.shape[1]))
        # Read cursor into the compact bitmask received from the scheduler.
        cumulative_index = 0
        seq = sorted(scheduler_output.structured_output_request_ids.items(),
                     key=lambda x: x[1])
        for req_id, _ in seq:
            logit_index = struct_out_req_batch_indices[req_id]
            num_spec_tokens = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            # One bitmask row per sampled position: the regular token plus
            # one for every scheduled speculative token.
            for i in range(1 + num_spec_tokens):
                sorted_bitmask[logit_index + i] = \
                    grammar_bitmask[cumulative_index + i]
                out_indices.append(logit_index + i)
            cumulative_index += 1 + num_spec_tokens
        grammar_bitmask = sorted_bitmask

        # Serialization of np.ndarray is much more efficient than a tensor,
        # so we receive it in that format.
        grammar_bitmask = torch.from_numpy(grammar_bitmask)

        # Force use of the torch.compile implementation from xgrammar to work
        # around issues with the Triton kernel in concurrent structured output
        # scenarios. See PR #19565 and issues #19493, #18376 for details.
        xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
            logits,
            grammar_bitmask.to(self.device, non_blocking=True),
            indices=out_indices,
        )

1170
1171
1172
1173
1174
1175
1176
    def sync_and_slice_intermediate_tensors(
            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
            sync_self: bool) -> IntermediateTensors:
        """Optionally refresh and then slice the persistent intermediate
        tensors to `num_tokens`.

        Args:
            num_tokens: Number of token rows to expose.
            intermediate_tensors: Source tensors copied into
                `self.intermediate_tensors` when `sync_self` is True.
            sync_self: Whether to copy `intermediate_tensors` into the
                persistent buffers before slicing.

        Returns:
            A new `IntermediateTensors` whose values are slices of
            `self.intermediate_tensors`. The "residual" entry is sliced to
            `num_tokens // tp` rows when it is scattered across tensor
            parallel ranks; every other entry to `num_tokens` rows.
        """

        assert self.intermediate_tensors is not None

        tp = self.vllm_config.parallel_config.tensor_parallel_size
        enabled_sp = self.compilation_config.pass_config. \
            enable_sequence_parallelism
        if enabled_sp:
            # When sequence parallelism is enabled, we always pad num_tokens
            # to be a multiple of tensor_parallel_size (tp) earlier
            assert num_tokens % tp == 0
        is_residual_scattered = tp > 1 and enabled_sp \
            and num_tokens % tp == 0

        # When sequence parallelism is enabled, the "residual" tensor is sharded
        # across tensor parallel ranks, so each rank only needs its own slice.
        if sync_self:
            assert intermediate_tensors is not None
            for k, v in intermediate_tensors.items():
                # Only the "residual" entry is scattered. The key must be
                # compared explicitly: a bare `"residual" and ...` is truthy
                # for every key and would shorten the copy of all tensors.
                is_scattered = k == "residual" and is_residual_scattered
                copy_len = num_tokens // tp if is_scattered else \
                    num_tokens
                self.intermediate_tensors[k][:copy_len].copy_(
                    v[:copy_len], non_blocking=True)

        return IntermediateTensors({
            k:
            v[:num_tokens // tp]
            if k == "residual" and is_residual_scattered else v[:num_tokens]
            for k, v in self.intermediate_tensors.items()
        })

1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
    def eplb_step(self,
                  is_dummy: bool = False,
                  is_profile: bool = False) -> None:
        """
        Advance the EPLB (Expert Parallelism Load Balancing) state by one step.

        Does nothing when EPLB is disabled in the parallel config.
        """
        parallel_config = self.parallel_config
        if not parallel_config.enable_eplb:
            return

        eplb_state = self.eplb_state
        assert eplb_state is not None
        assert is_mixture_of_experts(self.model)
        eplb_state.step(
            self.model,
            is_dummy,
            is_profile,
            log_stats=parallel_config.eplb_log_balancedness,
        )

1222
1223
    def get_dp_padding(self,
                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
        """Compute the padding that equalizes token counts across DP ranks.

        Args:
            num_tokens: Number of tokens on this rank.

        Returns:
            `(num_pad_tokens, num_tokens_after_padding)`. The first element
            is how many tokens this rank must add to reach the maximum
            across data parallel ranks. The second is a CPU int32 tensor of
            length `dp_size` holding that maximum for every rank. Returns
            `(0, None)` when `dp_size == 1` or `enforce_eager` is set.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        dp_rank = self.vllm_config.parallel_config.data_parallel_rank

        # For DP: Don't pad when setting enforce_eager.
        # This lets us set enforce_eager on the prefiller in a P/D setup and
        # still use CUDA graphs (enabled by this padding) on the decoder.
        #
        # TODO(tms) : There are many cases where padding is enabled for
        # prefills, causing unnecessary and excessive padding of activations.

        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
            # Early exit.
            return 0, None

        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
            num_tokens, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                                dp_size,
                                                device="cpu",
                                                dtype=torch.int32)
        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
1246

1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
    def _pool(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        num_scheduled_tokens_np: np.ndarray,
        finished_sending: Optional[set[str]],
        finished_recving: Optional[set[str]],
    ) -> ModelRunnerOutput:
        """
        Run the model's pooler over this step's hidden states.

        The first ``num_scheduled_tokens`` rows of ``hidden_states`` are split
        into per-request chunks using ``num_scheduled_tokens_np`` and handed to
        ``self.model.pooler``. A request's pooled output is moved to CPU only
        when its ``seq_len`` equals its ``prompt_len``; otherwise its slot in
        ``pooler_output`` is ``None``.
        """
        batch = self.input_batch
        assert batch.num_reqs ==\
            len(batch.pooling_params), \
        "Either all or none of the requests in" \
        " a batch must be pooling request"

        # One hidden-state chunk per request.
        chunk_sizes = num_scheduled_tokens_np.tolist()
        per_request_states = list(
            torch.split(hidden_states[:num_scheduled_tokens], chunk_sizes))

        metadata = batch.pooling_metadata
        raw_outputs = self.model.pooler(hidden_states=per_request_states,
                                        pooling_metadata=metadata)

        request_seq_lens = self.seq_lens[:batch.num_reqs]
        pooler_output: list[Optional[torch.Tensor]] = [
            raw.data.cpu() if seq_len == prompt_len else None
            for raw, seq_len, prompt_len in zip(raw_outputs, request_seq_lens,
                                                metadata.prompt_lens)
        ]

        return ModelRunnerOutput(
            req_ids=batch.req_ids,
            req_id_to_index=batch.req_id_to_index,
            sampled_token_ids=[],
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=pooler_output,
            finished_sending=finished_sending,
            finished_recving=finished_recving,
        )

1292
1293
1294
1295
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        """
        Run one scheduler step: forward pass, sampling and draft proposal.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Tensors received from the previous pipeline
                stage; ignored (reset to ``None``) on the first PP rank.

        Returns:
            * ``EMPTY_MODEL_RUNNER_OUTPUT`` (or the result of
              ``kv_connector_no_forward``) when no tokens are scheduled.
            * The raw model output on a non-last PP rank when PP output
              broadcasting is disabled.
            * The result of ``_pool`` when the batch has pooling params.
            * Otherwise a ``ModelRunnerOutput`` carrying sampled tokens,
              optional draft tokens, logprobs and KV-transfer status.
        """
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata,
         num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            # Pad tokens to multiple of tensor_parallel_size when
            # enabled collective fusion for SP
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            if self.compilation_config.pass_config. \
                enable_sequence_parallelism and tp_size > 1:
                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            else:
                num_input_tokens = num_scheduled_tokens

        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs

        # Run the model.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                skip_cuda_graphs=skip_cuda_graphs,
        ):
            self.maybe_setup_kv_connector(scheduler_output)

            model_output = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )

            self.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping micro-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = \
            self.parallel_config.distributed_executor_backend \
            == "external_launcher" and len(get_pp_group().ranks) > 0
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            assert logits is not None
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                self.draft_probs.get_probs(self.input_batch.req_ids) \
                    if self.draft_probs is not None else None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )
        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            if not sampled_ids:
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                           start_idx:end_idx] = sampled_ids
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
            req_id = self.input_batch.req_ids[req_idx]
            req_state = self.requests[req_id]
            req_state.output_token_ids.extend(sampled_ids)

        if not self.speculative_config:
            # Speculative decoding is not enabled.
            spec_token_ids = None
        else:
            spec_token_ids, draft_probs = self.propose_draft_token_ids(
                scheduler_output,
                valid_sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )

            if self.use_mtp:
                if self.draft_probs is None:
                    self.draft_probs = DraftProbs(
                        draft_probs, self.input_batch.req_ids)
                else:
                    self.draft_probs.update(draft_probs,
                                            self.input_batch.req_ids)

            # The ngram proposer already returns list[list[int]], which has
            # no `.tolist()`; only convert when a tensor was handed back.
            if isinstance(spec_token_ids, torch.Tensor):
                spec_token_ids = spec_token_ids.tolist()

        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

        self.eplb_step()

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
        )

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        attn_metadata: dict[str, Any],
    ) -> tuple[Any, Any]:
        """
        Ask the configured drafter for speculative tokens for the next step.

        Dispatches on ``self.speculative_config``: ``"ngram"``, ``"medusa"``
        or an EAGLE-style drafter (``use_eagle()``).

        Returns:
            A ``(spec_token_ids, draft_probs)`` pair. For ngram,
            ``spec_token_ids`` is a ``list[list[int]]``; for the other
            drafters it is whatever ``self.drafter.propose`` returns.
            ``draft_probs`` is only produced by the EAGLE branch and is
            ``None`` for ngram and medusa.
        """
        # Only the EAGLE branch assigns draft probabilities. Bind a default so
        # the shared return statement does not raise UnboundLocalError for the
        # ngram and medusa branches.
        draft_probs = None
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                # Pick, per request, the hidden state of its last accepted
                # token; each request occupies `num_draft + 1` rows.
                indices = []
                offset = 0
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            spec_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            next_token_ids: list[int] = []
            for i, token_ids in enumerate(sampled_token_ids):
                if token_ids:
                    # Common case.
                    next_token_id = token_ids[-1]
                else:
                    # Partial prefill (rare case).
                    # Get the next token id from the request state.
                    req_id = self.input_batch.req_ids[i]
                    req_state = self.requests[req_id]
                    seq_len = (req_state.num_computed_tokens +
                               scheduler_output.num_scheduled_tokens[req_id])
                    next_token_id = req_state.get_token_id(seq_len)
                next_token_ids.append(next_token_id)
            next_token_ids = torch.tensor(next_token_ids,
                                          dtype=torch.int32,
                                          device=self.device)
            # At this moment, we assume all eagle layers belong to the same KV
            # cache group, thus using the same attention metadata.
            eagle_attn_metadata = attn_metadata[
                self.drafter.attn_layer_names[0]]

            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
            if hasattr(eagle_attn_metadata, "block_table"):
                block_table = eagle_attn_metadata.block_table
            else:
                block_table = None

            num_rejected_tokens = None
            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
                target_slot_mapping = eagle_attn_metadata.slot_mapping
                cu_num_tokens = eagle_attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                num_draft_tokens = spec_decode_metadata.num_draft_tokens
                num_rejected_tokens = [
                    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                    for i, n in enumerate(num_draft_tokens)
                ]
                num_rejected_tokens_tensor = async_tensor_h2d(
                    num_rejected_tokens,
                    dtype=torch.int32,
                    target_device=self.device,
                    pin_memory=True)
                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    eagle_attn_metadata.query_start_loc,
                    num_rejected_tokens_tensor,
                    num_tokens,
                )
                target_token_ids = self.input_ids[token_indices]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[token_indices]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
                target_slot_mapping = eagle_attn_metadata.slot_mapping[
                    token_indices]
            spec_token_ids, draft_probs = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=block_table,
                sampling_metadata=sampling_metadata,
                num_rejected_tokens=num_rejected_tokens,
            )

        return spec_token_ids, draft_probs
1715

Robert Shaw's avatar
Robert Shaw committed
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
    def kv_connector_no_forward(
            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
        """
        Service the KV connector on a step with no forward pass.

        Returns ``EMPTY_MODEL_RUNNER_OUTPUT`` itself when no transfer has
        finished; otherwise a shallow copy of it carrying the finished
        sending/receiving results.
        """
        # KV send/recv even if no work to do.
        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)
            sent, received = self.get_finished_kv_transfers(scheduler_output)

        if sent or received:
            # Copy so the shared empty singleton is never mutated.
            result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
            result.finished_sending = sent
            result.finished_recving = received
            return result

        return EMPTY_MODEL_RUNNER_OUTPUT

    @staticmethod
    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
        """
        Bind this step's connector metadata and start loading KV caches.

        A no-op when no KV transfer group exists.
        """
        if not has_kv_transfer_group():
            return

        # Update KVConnector with the KVConnector metadata forward().
        connector = get_kv_transfer_group()
        assert isinstance(connector, KVConnectorBase_V1)
        metadata = scheduler_output.kv_connector_metadata
        assert metadata is not None
        connector.bind_connector_metadata(metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        connector.start_load_kv(get_forward_context())

    @staticmethod
    def maybe_wait_for_kv_save() -> None:
        """Call ``wait_for_save`` on the KV transfer group, if one exists."""
        if not has_kv_transfer_group():
            return
        connector = get_kv_transfer_group()
        connector.wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Query the KV transfer group for finished transfers.

        Passes ``scheduler_output.finished_req_ids`` to the connector's
        ``get_finished``. Returns ``(None, None)`` when no KV transfer group
        exists.
        """
        if not has_kv_transfer_group():
            return None, None
        connector = get_kv_transfer_group()
        return connector.get_finished(scheduler_output.finished_req_ids)

1762
    def propose_ngram_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """
        Propose draft tokens for every request with the ngram drafter.

        Args:
            sampled_token_ids: Tokens sampled this step, one list per request
                in ``self.input_batch`` order.

        Returns:
            One list of draft token ids per request. The list is empty when
            the request sampled nothing this step, is listed in
            ``spec_decode_unsupported_reqs``, has already reached
            ``max_model_len``, or the drafter proposed nothing.
        """
        # TODO(woosuk): Optimize.
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                draft_token_ids.append([])
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = self.input_batch.req_ids[i]
            if req_id in self.input_batch.spec_decode_unsupported_reqs:
                draft_token_ids.append([])
                continue

            num_tokens = self.input_batch.num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

            drafter_output = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :num_tokens])
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

1796
1797
1798
    def load_model(self) -> None:
        """
        Load (or reload in place) the model and everything attached to it.

        Inside a ``DeviceMemoryProfiler`` scope this either builds the model
        through the configured loader or, if ``self.model`` already exists,
        reloads its weights in place. It then wraps the model for LoRA, loads
        the drafter, and configures aux hidden-state layers when those
        features are enabled. Afterwards it records consumed memory, prepares
        communication buffers, and builds ``self.eplb_state`` for
        mixture-of-experts models with EPLB enabled.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            time_before_load = time.perf_counter()
            model_loader = get_model_loader(self.load_config)
            if not hasattr(self, "model"):
                logger.info("Loading model from scratch...")
                self.model = model_loader.load_model(
                    vllm_config=self.vllm_config,
                    model_config=self.model_config)
            else:
                logger.info(
                    "Model was already initialized. Loading weights inplace..."
                )
                model_loader.load_weights(self.model,
                                          model_config=self.model_config)
            if has_step_pooler(self.model):
                self.input_batch.logits_processing_needs_token_ids = True
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)
            if hasattr(self, "drafter"):
                logger.info("Loading drafter model...")
                self.drafter.load_model(self.model)
            if self.use_aux_hidden_state_outputs:
                self.model.set_aux_hidden_state_layers(
                    self.model.get_eagle3_aux_hidden_state_layers())
            time_after_load = time.perf_counter()
        self.model_memory_usage = m.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes,
                    time_after_load - time_before_load)
        prepare_communication_buffer_for_model(self.model)

        if is_mixture_of_experts(
                self.model) and self.parallel_config.enable_eplb:
            logger.info("EPLB is enabled for model %s.",
                        self.model_config.model)
            self.eplb_state = EplbState.build(
                self.model,
                self.device,
                self.parallel_config,
            )

1843
1844
1845
1846
1847
1848
1849
1850
1851
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Hand ``self.model`` to ``TensorizerLoader.save_model``."""
        save_model = TensorizerLoader.save_model
        save_model(self.model, tensorizer_config=tensorizer_config)

1852
1853
1854
1855
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for the requests that asked for them.

        For every request listed in `self.input_batch.num_prompt_logprobs`,
        the logprobs of the prompt tokens scheduled in this step are computed
        from `hidden_states` and copied into a per-request CPU buffer. The
        buffer lives in `self.input_batch.in_progress_prompt_logprobs_cpu`
        across steps (chunked prefill) and is only returned by the step that
        handles the last chunk of the prompt.

        Side effects: requests whose prefill completed are removed from both
        `num_prompt_logprobs` and `in_progress_prompt_logprobs_cpu`, and the
        device is synchronized when at least one result is returned.

        Args:
            hidden_states: Hidden states of this step. The rows of a request
                start at `self.query_start_loc_np[req_idx]`.
            scheduler_output: Supplies `num_scheduled_tokens` per request id.

        Returns:
            Mapping from request id to its finished prompt logprobs tensors.
            Empty when no request asked for prompt logprobs.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if not logprobs_tensors:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only the slice needed by this chunk
            # is moved to the device, and only once we know it is non-empty.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok +
                                         num_logits]).to(self.device,
                                                         non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        if prompt_logprobs_dict:
            self._sync_device()

        return prompt_logprobs_dict

1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
    def _get_nans_in_logits(
        self,
        logits: Optional[torch.Tensor],
    ) -> dict[str, int]:
        try:
            if logits is None:
                return {req_id: 0 for req_id in self.input_batch.req_ids}

            num_nans_in_logits = {}
            num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
            for req_id in self.input_batch.req_ids:
                req_index = self.input_batch.req_id_to_index[req_id]
                num_nans_in_logits[req_id] = (
                    int(num_nans_for_index[req_index])
                    if num_nans_for_index is not None
                    and req_index < logits.shape[0] else 0)
            return num_nans_in_logits
        except IndexError:
            return {}

1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
    @contextmanager
    def maybe_randomize_inputs(self, input_ids: Optional[torch.Tensor]):
        """
        Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
        This is to help balance expert-selection
         - during profile_run
         - during DP rank dummy run

        Randomization only happens when the data-parallel size is larger than
        one. While the context is active `input_ids` holds random token ids in
        `[0, vocab_size)`; on exit it is filled with zeros again, also when the
        wrapped code raises.

        Args:
            input_ids: Tensor that is overwritten in place. May be None (the
                dummy run passes None when it feeds `inputs_embeds`), in which
                case this context manager does nothing.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        randomize_inputs = (dp_size > 1
                            and envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS)
        if not randomize_inputs or input_ids is None:
            yield
            return

        logger.debug("Randomizing dummy data for DP Rank")
        # Draw exactly as many ids as needed, with the dtype/device of the
        # destination, instead of randomizing the whole persistent buffer.
        random_ids = torch.randint_like(
            input_ids, low=0, high=self.model_config.get_vocab_size())
        input_ids.copy_(random_ids, non_blocking=True)
        try:
            yield
        finally:
            # Always restore the dummy buffer, even if the forward pass fails.
            input_ids.fill_(0)

1996
1997
1998
1999
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        capture_attn_cudagraph: bool = False,
        skip_eplb: bool = False,
        is_profile: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model once on a dummy batch of `num_tokens` tokens.

        The tokens are split over dummy requests, the model forward pass is
        executed on the persistent input buffers, and (for EAGLE speculative
        decoding) the drafter's dummy run is triggered as well.

        Args:
            num_tokens: Number of dummy tokens; DP padding returned by
                `get_dp_padding` is added on top.
            capture_attn_cudagraph: When True, attention metadata is built
                with `build_for_cudagraph_capture` for every KV cache group;
                otherwise the forward context gets `None` as metadata.
            skip_eplb: When True, `eplb_step` is not called after the run.
            is_profile: Forwarded to `eplb_step`; also disables the
                speculative-decoding request split below.

        Returns:
            A tuple `(hidden_states, last_hidden_states)` where the second
            element holds the hidden state of the last token of each dummy
            request.
        """
        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
        num_tokens += num_pad

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs

        spec_config = self.speculative_config
        if (not is_profile and spec_config is not None
                and spec_config.num_lookahead_slots > 0):
            spec_tokens_per_req = 1 + spec_config.num_lookahead_slots
            # Only use the speculative layout when at least one full request
            # fits; otherwise num_reqs would become 0 and the split below
            # would fail. In that case the default split is kept.
            # NOTE(review): num_reqs is not clamped to max_num_seqs here,
            # matching the previous behavior - confirm this is intended.
            if num_tokens >= spec_tokens_per_req:
                min_tokens_per_req = spec_tokens_per_req
                num_reqs = num_tokens // min_tokens_per_req

        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        # Give the leftover tokens to the last request. The leftover must be
        # computed from the per-request size (not `num_tokens % num_reqs`),
        # otherwise the total is wrong whenever the per-request size is not
        # `num_tokens // num_reqs`.
        num_scheduled_tokens_list[-1] += (num_tokens -
                                          min_tokens_per_req * num_reqs)
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)

        attn_metadata: Optional[dict[str, Any]] = None
        if capture_attn_cudagraph:
            attn_metadata = {}

            query_start_loc = self.query_start_loc[:num_reqs + 1]
            # Make sure max_model_len is used at the graph capture time.
            self.seq_lens_np[:num_reqs] = self.max_model_len
            self.seq_lens_np[num_reqs:] = 0
            self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                           non_blocking=True)
            seq_lens = self.seq_lens[:num_reqs]

            num_speculative_tokens = (0 if spec_config is None else
                                      spec_config.num_lookahead_slots)
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=query_start_loc,
                seq_lens=seq_lens,
                num_reqs=num_reqs,
                num_actual_tokens=num_tokens,
                max_query_len=num_tokens,
                num_speculative_tokens=num_speculative_tokens,
            )

            # All layers of a KV cache group share one metadata object.
            for kv_cache_group_id, kv_cache_group_spec in enumerate(
                    self.kv_cache_config.kv_cache_groups):

                attn_metadata_i = self.attn_metadata_builders[
                    kv_cache_group_id].build_for_cudagraph_capture(
                        common_attn_metadata)
                for layer_name in kv_cache_group_spec.layer_names:
                    attn_metadata[layer_name] = attn_metadata_i

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            model = self.model
            if self.is_multimodal_model:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None
            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_tokens]
            else:
                positions = self.positions[:num_tokens]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                # Lazily allocate the buffer on the first dummy run.
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device))

                intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                    num_tokens, None, False)

            with self.maybe_randomize_inputs(input_ids), set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    num_tokens_across_dp=num_tokens_across_dp):
                outputs = model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )
            if self.use_aux_hidden_state_outputs:
                hidden_states, _ = outputs
            else:
                hidden_states = outputs

            if self.speculative_config and self.speculative_config.use_eagle():
                assert isinstance(self.drafter, EagleProposer)
                self.drafter.dummy_run(num_tokens, attn_metadata)

        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
        # requests to process.
        # However, in DP settings, there may be cases when some DP ranks do
        # not have any requests to process, so they're executing dummy batches.
        # In such cases, we still have to trigger EPLB to make sure
        # ranks execute the rearrangement in synchronization.
        if not skip_eplb:
            self.eplb_step(is_dummy=True, is_profile=is_profile)

        # Index of the last token of every dummy request.
        logit_indices = np.cumsum(num_scheduled_tokens) - 1
        return hidden_states, hidden_states[logit_indices]
2117
2118
2119
2120
2121
2122

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the sampler (and the rejection sampler) on dummy data.

        Args:
            hidden_states: Hidden states from a dummy run. Only the shape,
                dtype and device are used; the values are replaced by random
                numbers.

        Returns:
            The output of `self.sampler` for the dummy batch.

        Raises:
            RuntimeError: With a hint about `max_num_seqs` /
                `gpu_memory_utilization` when the sampler runs out of memory;
                any other RuntimeError is re-raised unchanged.
        """
        # The dummy hidden states may contain special values,
        # like `inf` or `nan`.
        # To avoid breaking the sampler, we use a random tensor here instead.
        hidden_states = torch.rand_like(hidden_states)

        logits = self.model.compute_logits(hidden_states, None)
        num_reqs = logits.size(0)

        def dummy_tensors(value) -> torch.Tensor:
            # One entry per dummy request, all set to `value`.
            return torch.full((num_reqs, ), value, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=dummy_tensors(0.5),
            all_greedy=False,
            all_random=False,
            top_p=dummy_tensors(0.9),
            top_k=dummy_tensors(logits.size(1) - 1),
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=dummy_tensors(0.1),
            presence_penalties=dummy_tensors(0.1),
            repetition_penalties=dummy_tensors(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=LogitsProcessorManager(),
        )
        try:
            sampler_output = self.sampler(logits=logits,
                                          sampling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                raise RuntimeError(
                    "CUDA out of memory occurred when warming up sampler with "
                    f"{num_reqs} dummy requests. Please try lowering "
                    "`max_num_seqs` or `gpu_memory_utilization` when "
                    "initializing the engine.") from e
            raise
        if self.speculative_config:
            # One dummy draft token per request.
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device)

            num_tokens = sum(len(ids) for ids in draft_token_ids)
            draft_probs = torch.randn(num_tokens,
                                      logits.shape[-1],
                                      device=self.device,
                                      dtype=logits.dtype)
            target_logits = torch.randn(num_tokens,
                                        logits.shape[-1],
                                        device=self.device,
                                        dtype=logits.dtype)
            # NOTE(woosuk): Here, we should use int32 because the sampler uses
            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
            # will occur at runtime.
            bonus_token_ids = torch.zeros(num_reqs,
                                          device=self.device,
                                          dtype=torch.int32)
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                draft_probs,
                target_logits,
                bonus_token_ids,
                dummy_metadata,
            )
        return sampler_output
2192

2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
    @torch.inference_mode()
    def _dummy_pooler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the model's pooler on a dummy batch.

        The rows of `hidden_states` are split over
        `min(num_tokens, max_num_seqs)` dummy requests, with the last request
        taking the leftover rows, and the chunks are passed to
        `self.model.pooler`.

        Args:
            hidden_states: Hidden states from a dummy run; the first
                dimension is the number of tokens.

        Returns:
            The pooler output for the dummy batch.

        Raises:
            RuntimeError: With a hint about `max_num_seqs` /
                `gpu_memory_utilization` when the pooler runs out of memory;
                any other RuntimeError is re-raised.
        """
        num_tokens = hidden_states.shape[0]
        num_reqs = min(num_tokens, self.scheduler_config.max_num_seqs)

        # Even split, with the remainder appended to the final request.
        req_num_tokens, leftover = divmod(num_tokens, num_reqs)
        num_scheduled_tokens_list = [req_num_tokens] * num_reqs
        num_scheduled_tokens_list[-1] += leftover
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        hidden_states_list = list(
            torch.split(hidden_states, num_scheduled_tokens_list))

        prompt_lens = torch.tensor(
            [chunk.shape[0] for chunk in hidden_states_list],
            device=self.device)
        prompt_token_ids = torch.zeros((num_reqs, req_num_tokens),
                                       dtype=torch.int32,
                                       device=self.device)
        dummy_metadata = PoolingMetadata(
            prompt_lens=prompt_lens,
            prompt_token_ids=prompt_token_ids,
            pooling_params=[PoolingParams()] * num_reqs)

        try:
            return self.model.pooler(hidden_states=hidden_states_list,
                                     pooling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise e
            raise RuntimeError(
                "CUDA out of memory occurred when warming up pooler with "
                f"{num_reqs} dummy requests. Please try lowering "
                "`max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine.") from e

2235
    def profile_run(self) -> None:
        """Execute a dummy pass with `self.max_num_tokens` tokens.

        For multimodal models with a non-zero encoder budget, the multimodal
        encoder is first run on a dummy batch and its outputs are stored in
        `self.encoder_cache`. Then the model is run via `_dummy_run` with
        `is_profile=True`, followed by the pooler or sampler dummy run on the
        last pipeline-parallel rank. Afterwards the device is synchronized,
        the encoder cache is cleared and garbage collection is triggered.
        """
        # Profile with multimodal encoder & encoder cache.
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):

            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            max_tokens_by_modality_dict = self.mm_registry \
                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
            # the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                                   max_tokens_per_mm_item)

            # Check how many items of this modality can be supported by
            # the decoder budget.
            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)[dummy_data_modality]

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = self.max_num_reqs * \
                max_mm_items_per_req

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)

            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_counts={
                    dummy_data_modality: 1
                },
            ).multi_modal_data

            # The same single-item dummy input is repeated for every item.
            batched_dummy_mm_inputs = MultiModalKwargs.batch(
                [dummy_mm_kwargs] * max_num_mm_items,
                pin_memory=self.pin_memory)
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_dummy_mm_inputs,
                device=self.device,
            )

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)

            sanity_check_mm_encoder_outputs(
                dummy_encoder_outputs,
                expected_num_items=max_num_mm_items,
            )

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        # Add `is_profile` here to pre-allocate communication buffers
        hidden_states, last_hidden_states \
            = self._dummy_run(self.max_num_tokens, is_profile=True)
        if get_pp_group().is_last_rank:
            if self.is_pooling_model:
                output = self._dummy_pooler_run(hidden_states)
            else:
                output = self._dummy_sampler_run(last_hidden_states)
        else:
            output = None
        self._sync_device()
        # Drop the references before collecting so the memory can be freed.
        del hidden_states, output
        self.encoder_cache.clear()
        gc.collect()
2319
2320

    def capture_model(self) -> None:
        """Capture CUDA graphs for all sizes in `self.cudagraph_batch_sizes`.

        Does nothing except logging a warning when `self.use_cuda_graph` is
        False. Otherwise each size is run
        `compilation_config.cudagraph_num_of_warmups` times as warmup plus
        once more, all through `_dummy_run` with EPLB skipped, and the elapsed
        time and the drop in free GPU memory are logged.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "set -O %s and ensure `use_cudagraph` was not manually set to "
                "False", CompilationLevel.PIECEWISE)
            return

        compilation_counter.num_gpu_runner_capture_triggers += 1

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            full_cg = self.full_cuda_graph
            # Only rank 0 should print progress bar during capture
            compilation_cases = reversed(self.cudagraph_batch_sizes)
            if is_global_first_rank():
                compilation_cases = tqdm(list(compilation_cases),
                                         desc="Capturing CUDA graph shapes")
            num_warmups = self.compilation_config.cudagraph_num_of_warmups
            for num_tokens in compilation_cases:
                # We skip EPLB here since we don't want to record dummy metrics
                # The warmup runs and the final run use identical arguments.
                for _ in range(num_warmups + 1):
                    self._dummy_run(num_tokens,
                                    capture_attn_cudagraph=full_cg,
                                    skip_eplb=True)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / GiB_bytes)
2361

2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.

        One backend and one metadata builder are appended to
        `self.attn_backends` / `self.attn_metadata_builders` per KV cache
        group, in group order.

        Args:
            kv_cache_config: The KV cache configuration whose groups are
                iterated.

        Raises:
            NotImplementedError: If `get_attn_backend` returns None for an
                attention spec.
            ValueError: If a group's spec is neither an `AttentionSpec` nor a
                `MambaSpec`, or if full CUDA graphs are requested (on a
                non-ROCm platform) but the builder does not support them.
        """
        assert len(self.attn_backends) == 0 and len(
            self.attn_metadata_builders
        ) == 0, "Attention backends are already initialized"
        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            if isinstance(kv_cache_spec, AttentionSpec):
                attn_backend_i = get_attn_backend(
                    kv_cache_spec.head_size,
                    self.dtype,
                    kv_cache_spec.dtype,
                    kv_cache_spec.block_size,
                    self.model_config.is_attention_free,
                    use_mla=kv_cache_spec.use_mla,
                )
                if attn_backend_i is None:
                    # Log the full lookup key before failing.
                    error_msg = (f"Error with get_attn_backend: "
                                 f"{kv_cache_spec.head_size=}, "
                                 f"{self.dtype=}, {kv_cache_spec.dtype=}, "
                                 f"{kv_cache_spec.block_size=}, "
                                 f"{self.model_config.is_attention_free=}, "
                                 f"{kv_cache_spec.use_mla=}")
                    logger.error(error_msg)
                    raise NotImplementedError(
                        "Non-Attention backend is not supported by V1 "
                        "GPUModelRunner.")
            elif isinstance(kv_cache_spec, MambaSpec):
                attn_backend_i = Mamba2AttentionBackend
            else:
                raise ValueError(
                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")

            block_table_i = self.input_batch.block_table[i]
            # The builder only gets a weak proxy of the runner.
            attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
                weakref.proxy(self),
                kv_cache_spec,
                block_table_i,
            )

            if (not current_platform.is_rocm() and self.full_cuda_graph
                    and not attn_metadata_builder_i.full_cudagraph_supported):
                raise ValueError(
                    f"Full CUDAGraph not supported for "
                    f"{attn_backend_i.__name__}. Turn off CompilationConfig."
                    f"full_cuda_graph or use a different attention backend.")

            self.attn_backends.append(attn_backend_i)
            self.attn_metadata_builders.append(attn_metadata_builder_i)

2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
    def may_reinitialize_input_batch(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Re-initialize the input batch if the block sizes are different from
        `[self.cache_config.block_size]`. This usually happens when there
        are multiple KV cache groups.

        Args:
            kv_cache_config: The KV cache configuration.
        """
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]
        if block_sizes == [self.cache_config.block_size]:
            # The existing input batch already matches; nothing to do.
            return

        assert self.cache_config.cpu_offload_gb == 0, (
            "Cannot re-initialize the input batch when CPU weight "
            "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
            "for more details.")
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=block_sizes,
            is_spec_decode=bool(self.vllm_config.speculative_config),
        )

2445
2446
    def _allocate_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Initializes the KV cache buffer with the correct size. The buffer needs
        to be reshaped to the desired shape before being used by the models.

        Each entry of `kv_cache_config.kv_cache_tensors` becomes one
        zero-filled int8 tensor of `size` elements on `self.device`; all
        layers listed in its `shared_by` map to that same tensor object.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            tensor = torch.zeros(kv_cache_tensor.size,
                                 dtype=torch.int8,
                                 device=self.device)
            for layer_name in kv_cache_tensor.shared_by:
                kv_cache_raw_tensors[layer_name] = tensor

        # Every layer of every KV cache group must have received a buffer,
        # and no buffer may belong to an unknown layer.
        layer_names = set()
        for group in kv_cache_config.kv_cache_groups:
            layer_names.update(group.layer_names)
        assert layer_names == set(kv_cache_raw_tensors.keys(
        )), "Some layers are not correctly initialized"
        return kv_cache_raw_tensors

    def _reshape_kv_cache_tensors(
        self,
        kv_cache_config: KVCacheConfig,
        kv_cache_raw_tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
2477
        """
2478
        Reshape the KV cache tensors to the desired shape and dtype.
2479

2480
        Args:
2481
2482
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer, with
2483
2484
            correct size but uninitialized shape.
        Returns:
2485
            Dict[str, torch.Tensor]: A map between layer names to their
2486
2487
            corresponding memory buffer for KV cache.
        """
2488
        kv_caches: dict[str, torch.Tensor] = {}
2489
        has_attn, has_mamba = False, False
2490
2491
2492
2493
2494
2495
2496
2497
        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            for layer_name in kv_cache_group_spec.layer_names:
                raw_tensor = kv_cache_raw_tensors[layer_name]
                assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                num_blocks = (raw_tensor.numel() //
                              kv_cache_spec.page_size_bytes)
2498
                if isinstance(kv_cache_spec, AttentionSpec):
2499
                    has_attn = True
2500
                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
2501
2502
2503
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
                    try:
                        kv_cache_stride_order = self.attn_backends[
                            i].get_kv_cache_stride_order()
                        assert len(kv_cache_stride_order) == len(
                            kv_cache_shape)
                    except (AttributeError, NotImplementedError):
                        kv_cache_stride_order = tuple(
                            range(len(kv_cache_shape)))
                    # The allocation respects the backend-defined stride order
                    # to ensure the semantic remains consistent for each
                    # backend. We first obtain the generic kv cache shape and
                    # then permute it according to the stride order which could
                    # result in a non-contiguous tensor.
                    kv_cache_shape = tuple(kv_cache_shape[i]
                                           for i in kv_cache_stride_order)
                    # Maintain original KV shape view.
                    inv_order = [
                        kv_cache_stride_order.index(i)
                        for i in range(len(kv_cache_stride_order))
                    ]
2524
2525
2526
                    kv_caches[layer_name] = kv_cache_raw_tensors[
                        layer_name].view(dtype).view(kv_cache_shape).permute(
                            *inv_order)
Chen Zhang's avatar
Chen Zhang committed
2527
                elif isinstance(kv_cache_spec, MambaSpec):
2528
                    has_mamba = True
Chen Zhang's avatar
Chen Zhang committed
2529
2530
                    raw_tensor = kv_cache_raw_tensors[layer_name]
                    dtype = kv_cache_spec.dtype
2531
2532
                    num_element_per_page = (kv_cache_spec.page_size_bytes //
                                            get_dtype_size(dtype))
Chen Zhang's avatar
Chen Zhang committed
2533
                    state_tensors = []
2534
                    storage_offset = 0
Chen Zhang's avatar
Chen Zhang committed
2535
2536
                    for shape in kv_cache_spec.shapes:
                        target_shape = (num_blocks, *shape)
2537
2538
2539
2540
2541
2542
2543
2544
                        stride = torch.empty(target_shape).stride()
                        target_stride = (num_element_per_page, *stride[1:])
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset,
                        )
Chen Zhang's avatar
Chen Zhang committed
2545
                        state_tensors.append(tensor)
2546
2547
2548
                        storage_offset += stride[0]

                    kv_caches[layer_name] = state_tensors
2549
                else:
2550
                    raise NotImplementedError
2551
2552
2553
2554
2555

        if has_attn and has_mamba:
            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
                                                       kv_cache_raw_tensors)

2556
2557
        return kv_caches

2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
    def _verify_hybrid_attention_mamba_layout(
            self, kv_cache_config: KVCacheConfig,
            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
        """
        Check the attention KV cache layout of a hybrid attention/mamba
        model.

        For every layer that belongs to an attention KV cache group, the
        shape reported by the group's attention backend must begin with
        `(num_blocks, 2, ...)`.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer.

        Raises:
            ValueError: If an attention backend reports a KV cache shape
                whose first two dimensions are not `(num_blocks, 2)`.
        """
        for group_idx, group in enumerate(kv_cache_config.kv_cache_groups):
            spec = group.kv_cache_spec
            for layer_name in group.layer_names:
                buffer = kv_cache_raw_tensors[layer_name]
                num_blocks = buffer.numel() // spec.page_size_bytes

                # Only attention layers carry a layout constraint.
                if not isinstance(spec, AttentionSpec):
                    continue

                backend = self.attn_backends[group_idx]
                shape = backend.get_kv_cache_shape(num_blocks,
                                                   spec.block_size,
                                                   spec.num_kv_heads,
                                                   spec.head_size)
                if shape[0] == num_blocks and shape[1] == 2:
                    continue

                raise ValueError(
                    "Hybrid models in V1 require an attention "
                    "backend with kv_cache_shape="
                    "(num_blocks, 2, ...). Please try setting "
                    "VLLM_ATTENTION_BACKEND=FLASHINFER")

2589
2590
2591
2592
2593
2594
2595
2596
    def initialize_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Allocate the KV cache buffers, reshape them and bind them.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        # Raw buffers first, then the per-layer views built on top of them.
        raw_buffers = self._allocate_kv_cache_tensors(kv_cache_config)
        layer_kv_caches = self._reshape_kv_cache_tensors(
            kv_cache_config, raw_buffers)

        # Models with cross-layer KV sharing need `kv_cache_config` and the
        # per-layer caches to be set up for the sharing layers as well.
        shared_layers = self.shared_kv_cache_layers
        if shared_layers:
            initialize_kv_cache_for_kv_sharing(
                shared_layers,
                kv_cache_config.kv_cache_groups,
                layer_kv_caches,
            )

        forward_context = self.compilation_config.static_forward_context
        bind_kv_cache(layer_kv_caches, forward_context, self.kv_caches)
        return layer_kv_caches

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Set up the KV cache described by `kv_cache_config`.

        Stores the config on the runner, lets the input batch and the
        attention backends be (re)initialized for it, creates the KV cache
        tensors and, when a KV transfer group exists, registers the tensors
        with it.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer
        """
        self.kv_cache_config = kv_cache_config
        self.may_reinitialize_input_batch(kv_cache_config)
        self.initialize_attn_backend(kv_cache_config)
        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

        spec_config = self.speculative_config
        if spec_config and spec_config.use_eagle():
            drafter = self.drafter
            assert isinstance(drafter, EagleProposer)
            # All draft model layers must belong to the same kv cache group.
            drafter.validate_same_kv_cache_group(kv_cache_config)

        if not has_kv_transfer_group():
            return
        get_kv_transfer_group().register_kv_caches(kv_caches)

2641
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Build the per-layer KVCacheSpec from the Attention and MambaMixer2
        modules registered in the vLLM config.

        Attention layers that reuse another layer's KV cache are recorded in
        `self.shared_kv_cache_layers` instead of receiving a spec.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: For encoder-decoder attention, or when mamba
                layers are combined with speculative decoding, cuda graphs
                or prefix caching.
            ValueError: If a layer reports an unknown attention type.
        """
        vllm_config = self.vllm_config
        block_size = vllm_config.cache_config.block_size
        use_mla = vllm_config.model_config.use_mla
        specs: dict[str, KVCacheSpec] = {}

        attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            kv_tgt_layer = attn_module.kv_sharing_target_layer_name
            if kv_tgt_layer is not None:
                # This layer reads the KV cache of its target layer, so it
                # gets no KVCacheSpec of its own. KV cache management then
                # behaves as if the layer did not exist and allocates nothing
                # for it, which is where the memory saving of cross-layer KV
                # sharing comes from (longer contexts or more concurrent
                # requests for the same amount of memory).
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            attn_type = attn_module.attn_type
            if attn_type in (AttentionType.ENCODER,
                             AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            if attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            if attn_type != AttentionType.DECODER:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

            sliding_window = attn_module.sliding_window
            if sliding_window is None:
                specs[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype,
                    use_mla=use_mla)
            else:
                specs[layer_name] = SlidingWindowSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype,
                    sliding_window=sliding_window,
                    use_mla=use_mla)

        mamba_layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
        if not mamba_layers:
            return specs

        # Feature combinations that are rejected for mamba layers.
        if vllm_config.speculative_config is not None:
            raise NotImplementedError(
                "Mamba with speculative decoding is not supported yet.")
        if not vllm_config.model_config.enforce_eager:
            raise NotImplementedError(
                "Mamba with cuda graph is not supported yet.")
        if vllm_config.cache_config.enable_prefix_caching:
            raise NotImplementedError(
                "Prefix caching is not supported for Mamba yet.")

        max_model_len = vllm_config.model_config.max_model_len
        page_size_padded = self._maybe_pad_mamba_page_size(
            attn_layers, mamba_layers, specs, max_model_len, block_size)

        # block_size is set to max_model_len so that a mamba layer always
        # has exactly one block in the KV cache.
        specs.update({
            layer_name: MambaSpec(shapes=mamba_module.get_state_shape(),
                                  dtype=self.kv_cache_dtype,
                                  block_size=max_model_len,
                                  page_size_padded=page_size_padded)
            for layer_name, mamba_module in mamba_layers.items()
        })
        return specs
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772

    def _maybe_pad_mamba_page_size(
        self,
        attn_layers: dict[str, Attention],
        mamba_layers: dict[str, MambaMixer2],
        kv_cache_spec: dict[str, KVCacheSpec],
        max_model_len: int,
        block_size: int,
    ) -> Optional[int]:
        """
        Ensure that page size of attention KV cache groups is greater than or
        equal to the mamba KV cache groups. If not, we suggest to the user
        how to set the attention block size to ensure that it is.

        When the attention page size is large enough, it is returned so that
        the caller can use it as the (padded) mamba page size.

        Args:
            attn_layers: Attention layers
            mamba_layers: Mamba layers
            kv_cache_spec: KV cache spec (populated with attention layers)
            max_model_len: Maximum model length; used as the block size of
                the mamba spec whose page size is compared.
            block_size: Attention block size; only used to compute the
                block size suggested in the error message.

        Returns:
            Optional[int]: The attention page size, to be used as the padded
            mamba page size. None if there is no mamba layer or no attention
            layer with a KV cache spec to compare against.

        Raises:
            ValueError: If the attention page size is smaller than the mamba
                page size.
        """

        if not attn_layers or not mamba_layers:
            return None

        # Not every attention layer has an entry in `kv_cache_spec`:
        # get_kv_cache_spec skips layers that share another layer's KV cache
        # as well as encoder / encoder-only layers. Indexing with the first
        # attention layer unconditionally could therefore raise KeyError, so
        # use the first attention layer that actually has a spec.
        attn_layer_name = next(
            (name for name in attn_layers if name in kv_cache_spec), None)
        if attn_layer_name is None:
            return None

        attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
        mamba_layer_name = next(iter(mamba_layers))
        mamba_page_size = MambaSpec(
            shapes=mamba_layers[mamba_layer_name].get_state_shape(),
            dtype=self.kv_cache_dtype,
            block_size=max_model_len).page_size_bytes
        if attn_page_size < mamba_page_size:
            # attention page size (for 16 tokens)
            attn_page_size_16 = 16 * attn_page_size // block_size
            # some attention backends (e.g. FA) only support setting
            # block size to multiple of 16, so let's suggest a value
            # that would work (note: FA is currently not compatible
            # with mamba layers, use FlashInfer instead).
            suggest_attn_block_size = 16 * cdiv(mamba_page_size,
                                                attn_page_size_16)
            raise ValueError(
                "Attention block size should be increased to at least "
                f"{suggest_attn_block_size} in order to match "
                "the mamba page size")

        return attn_page_size