gpu_model_runner.py 286 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
import gc
6
import itertools
7
import threading
8
import time
9
from collections import defaultdict
10
from collections.abc import Iterator, Sequence
11
from contextlib import contextmanager
12
from copy import copy, deepcopy
13
from dataclasses import dataclass
14
from functools import reduce
王敏's avatar
王敏 committed
15
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast, Optional
16
17
18
19
20

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
21
from tqdm import tqdm
22

23
import vllm.envs as envs
24
from vllm.attention.layer import Attention, MLAAttention
25
from vllm.compilation.counter import compilation_counter
26
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
27
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
28
from vllm.config import (
29
    CompilationMode,
30
31
32
33
34
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
    update_config,
)
35
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
36
from vllm.distributed.eplb.eplb_state import EplbState
37
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
38
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
39
from vllm.distributed.parallel_state import (
40
    get_dcp_group,
41
42
43
44
    get_pp_group,
    get_tp_group,
    graph_capture,
    is_global_first_rank,
45
46
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
47
48
    prepare_communication_buffer_for_model,
)
49
50
51
from vllm.distributed import (
    tensor_model_parallel_all_gather
)
52
53
54
55
from vllm.forward_context import (
    BatchDescriptor,
    set_forward_context,
)
56
from vllm.logger import init_logger
57
from vllm.lora.layers import LoRAMapping, LoRAMappingType
58
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
59
60
61
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    RoutedExpertsCapturer,
)
62
63
64
65
from vllm.model_executor.layers.rotary_embedding import (
    MRotaryEmbedding,
    XDRotaryEmbedding,
)
66
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
67
from vllm.model_executor.models.interfaces import (
68
    MultiModalEmbeddings,
69
    SupportsMRoPE,
70
    SupportsMultiModal,
71
    SupportsXDRoPE,
72
73
74
75
76
    is_mixture_of_experts,
    supports_eagle3,
    supports_mrope,
    supports_multimodal_pruning,
    supports_transcription,
77
    supports_xdrope,
78
)
79
from vllm.model_executor.models.interfaces_base import (
80
81
82
83
    VllmModelForPooling,
    is_pooling_model,
    is_text_generation_model,
)
84
from vllm.multimodal import MULTIMODAL_REGISTRY
85
86
87
88
89
from vllm.multimodal.inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    PlaceholderRange,
)
90
from vllm.multimodal.utils import group_mm_kwargs_by_modality
91
from vllm.pooling_params import PoolingParams
92
from vllm.sampling_params import SamplingType
93
from vllm.sequence import IntermediateTensors
94
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
95
from vllm.utils import length_from_prompt_token_ids_or_embeds
96
from vllm.utils.jsontree import json_map_leaves
97
from vllm.utils.math_utils import cdiv, round_up
98
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
99
from vllm.utils.nvtx_pytorch_hooks import PytHooks
100
from vllm.utils.platform_utils import is_pin_memory_available
101
102
103
104
from vllm.utils.torch_utils import (
    get_dtype_size,
    kv_cache_dtype_str_to_dtype,
)
105
106
from vllm.v1.attention.backend import (
    AttentionBackend,
107
    AttentionCGSupport,
108
    AttentionMetadata,
109
    AttentionMetadataBuilder,
110
    AttentionType,
111
    CommonAttentionMetadata,
112
    CpCommonAttentionMetadata,
113
114
    MultipleOf,
)
115
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
116
from vllm.v1.attention.backends.utils import (
117
    create_fast_prefill_custom_backend,
118
    get_dcp_local_seq_lens,
119
120
    reorder_batch_to_split_decodes_and_prefills,
)
121
from vllm.v1.core.sched.output import NewRequestData
122
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    ChunkedLocalAttentionSpec,
    CrossAttentionSpec,
    EncoderOnlyAttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
    KVCacheSpec,
    MambaSpec,
    SlidingWindowSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
140
    ECConnectorOutput,
141
    KVConnectorOutput,
142
143
144
145
146
    LogprobsLists,
    LogprobsTensors,
    ModelRunnerOutput,
    PoolerOutput,
    SamplerOutput,
147
    make_empty_encoder_model_runner_output,
148
)
149
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
150
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
151
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
152
from vllm.v1.sample.metadata import SamplingMetadata
153
from vllm.v1.sample.rejection_sampler import RejectionSampler
王敏's avatar
王敏 committed
154
from vllm.v1.sample.rejection_sampler_opt import OptRejectionSampler
155
from vllm.v1.sample.sampler import Sampler
156
from vllm.v1.spec_decode.draft_model import DraftModelProposer
157
from vllm.v1.spec_decode.eagle import EagleProposer
158
from vllm.v1.spec_decode.medusa import MedusaProposer
159
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
160
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
161
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
162
from vllm.v1.structured_output.utils import apply_grammar_bitmask
163
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
164
165
166
167
168
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.cp_utils import (
    check_attention_cp_compatibility,
    get_total_cp_world_size,
)
169
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
170
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
171
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
172
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
173
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
174
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
175
176
177
from vllm.v1.worker.ubatch_utils import (
    UBatchSlices,
    check_ubatch_thresholds,
178
    maybe_create_ubatch_slices,
179
    split_attn_metadata,
180
)
181
from vllm.v1.worker.utils import is_residual_scattered_for_sp
182
from vllm.v1.worker.workspace import lock_workspace
183

184
185
186
187
188
189
190
from .utils import (
    AttentionGroup,
    MultiModalBudget,
    add_kv_sharing_layers_to_kv_cache_groups,
    bind_kv_cache,
    sanity_check_mm_encoder_outputs,
)
王敏's avatar
王敏 committed
191
from vllm.v1.spec_decode.utils import DraftProbs
192
from vllm.utils.torch_utils import async_tensor_h2d
193

194
if TYPE_CHECKING:
195
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
196
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
197
198
199

# Attention metadata, presumably keyed by layer name (str) — confirm at
# the call sites that build it.
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# When ubatching is enabled this is a list of such dicts (presumably one per
# micro-batch); otherwise it is a single dict.
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict

logger = init_logger(__name__)
203

204

205
206
207
208
209
210
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
    """ModelRunnerOutput wrapper whose GPU->CPU transfer overlaps other work.

    The constructor enqueues a non-blocking copy of the sampled token ids
    (and of the logprobs tensors, when given) on a dedicated stream and
    returns without waiting. ``get_output`` waits for that copy and fills
    the wrapped ``ModelRunnerOutput``.
    """

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampled_token_ids: torch.Tensor,
        logprobs_tensors: LogprobsTensors | None,
        invalid_req_indices: list[int],
        async_output_copy_stream: torch.cuda.Stream,
        vocab_size: int,
    ):
        """Start the asynchronous host copy without synchronizing it.

        Args:
            model_runner_output: Output object that ``get_output`` populates.
            sampled_token_ids: Device tensor of sampled token ids.
            logprobs_tensors: Optional device logprobs copied alongside.
            invalid_req_indices: Request indices whose sampled tokens are
                cleared (single-slot path) or forwarded to
                ``RejectionSampler.parse_output`` (multi-slot path).
            async_output_copy_stream: Stream on which the copy is issued.
            vocab_size: Forwarded to ``RejectionSampler.parse_output``.
        """
        self._model_runner_output = model_runner_output
        self._invalid_req_indices = invalid_req_indices
        self.vocab_size = vocab_size

        # Recorded on the copy stream after the copies are enqueued so that
        # get_output() can wait for the non-blocking transfer.
        self.async_copy_ready_event = torch.Event()

        # Hold the device tensors so they are not freed while the
        # asynchronous copy may still be reading from them.
        self._sampled_token_ids = sampled_token_ids
        self._logprobs_tensors = logprobs_tensors

        producer_stream = torch.cuda.current_stream()
        with torch.cuda.stream(async_output_copy_stream):
            # Order the copy after all work already queued on the
            # producer stream.
            async_output_copy_stream.wait_stream(producer_stream)
            self.sampled_token_ids_cpu = sampled_token_ids.to(
                "cpu", non_blocking=True
            )
            if logprobs_tensors:
                self._logprobs_tensors_cpu = logprobs_tensors.to_cpu_nonblocking()
            else:
                self._logprobs_tensors_cpu = None
            self.async_copy_ready_event.record()

    def get_output(self) -> ModelRunnerOutput:
        """Wait for the host copy and return the populated ModelRunnerOutput.

        Blocks until the copy enqueued in ``__init__`` has completed. Then
        drops the device-tensor references and writes ``sampled_token_ids``
        and ``logprobs`` onto the wrapped output.
        """
        # A trailing width of 1 takes the simple path; wider outputs are
        # decoded by RejectionSampler.parse_output.
        num_token_slots = self.sampled_token_ids_cpu.shape[-1]
        self.async_copy_ready_event.synchronize()

        # The transfer is done, so the device tensors can be released.
        del self._logprobs_tensors
        del self._sampled_token_ids

        if num_token_slots != 1:
            sampled, logprobs = RejectionSampler.parse_output(
                self.sampled_token_ids_cpu,
                self.vocab_size,
                self._invalid_req_indices,
                logprobs_tensors=self._logprobs_tensors_cpu,
            )
        else:
            sampled = self.sampled_token_ids_cpu.tolist()
            for req_idx in self._invalid_req_indices:
                sampled[req_idx].clear()
            if self._logprobs_tensors_cpu is None:
                logprobs = None
            else:
                logprobs = self._logprobs_tensors_cpu.tolists()

        result = self._model_runner_output
        result.sampled_token_ids = sampled
        result.logprobs = logprobs
        return result

    def get_output_async(self) -> ModelRunnerOutput:
        """Return the wrapped output immediately, without waiting.

        Unlike ``get_output``, this neither synchronizes the copy nor sets
        ``sampled_token_ids`` / ``logprobs`` on the returned object.
        """
        return self._model_runner_output

276

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
    """Pooling-output wrapper whose GPU->CPU transfer overlaps other work.

    The constructor enqueues non-blocking host copies of every non-None leaf
    of the raw pooler output on a dedicated stream. ``get_output`` waits for
    those copies before handing back the wrapped ``ModelRunnerOutput``.
    """

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        raw_pooler_output: PoolerOutput,
        finished_mask: list[bool],
        async_output_copy_stream: torch.cuda.Stream,
    ):
        """Start the asynchronous host copy and pre-fill ``pooler_output``.

        Args:
            model_runner_output: Output object whose ``pooler_output`` is set.
            raw_pooler_output: Device-side pooler output; leaves that are
                None stay None.
            finished_mask: Parallel to the pooler outputs; entries that are
                falsy produce None in ``pooler_output``.
            async_output_copy_stream: Stream on which the copy is issued.
        """
        self._model_runner_output = model_runner_output

        # Recorded on the copy stream once the copies are enqueued so that
        # get_output() can wait for the non-blocking transfer.
        self.async_copy_ready_event = torch.Event()

        # Keep the device tensors alive until the copy has finished.
        self._raw_pooler_output = raw_pooler_output

        def _to_host(leaf):
            # Non-blocking device->host copy of one leaf; None passes through.
            return None if leaf is None else leaf.to("cpu", non_blocking=True)

        producer_stream = torch.cuda.current_stream()
        with torch.cuda.stream(async_output_copy_stream):
            # Order the copy after all work already queued on the
            # producer stream.
            async_output_copy_stream.wait_stream(producer_stream)
            host_outputs = json_map_leaves(_to_host, raw_pooler_output)
            self.async_copy_ready_event.record()

        # Only references are gathered here; the host tensors must not be
        # read until get_output() has synchronized on the event.
        model_runner_output.pooler_output = [
            out if include else None
            for out, include in zip(host_outputs, finished_mask)
        ]

    def get_output(self) -> ModelRunnerOutput:
        """Wait for the host copy started in ``__init__`` and return the output.

        Blocks until the copy has completed. Afterwards the reference to the
        device-side pooler output is dropped.
        """
        self.async_copy_ready_event.synchronize()

        # Safe to release the device tensors now that the transfer is done.
        del self._raw_pooler_output
        return self._model_runner_output


319
320
321
class ExecuteModelState(NamedTuple):
    """Transient state handed from execute_model() to sample_tokens().

    An instance is cached after execute_model() returns None so that it can
    be read by sample_tokens(). Field order is part of the tuple interface
    (positional construction/unpacking) and must not change.
    """

    # Forward reference: SchedulerOutput is only imported under TYPE_CHECKING.
    scheduler_output: "SchedulerOutput"
    logits: torch.Tensor
    # Presumably None when speculative decoding is not active for the
    # step — TODO confirm at the construction site.
    spec_decode_metadata: SpecDecodeMetadata | None
    spec_decode_common_attn_metadata: CommonAttentionMetadata | None
    hidden_states: torch.Tensor
    sample_hidden_states: torch.Tensor
    aux_hidden_states: list[torch.Tensor] | None
    ec_connector_output: ECConnectorOutput | None
    cudagraph_stats: CUDAGraphStat | None
    # Either one mapping, a list of mappings, or None.
    slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None
333
334


335
336
337
class GPUModelRunner(
    LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
338
339
    def __init__(
        self,
340
        vllm_config: VllmConfig,
341
        device: torch.device,
342
    ):
343
344
345
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
346
        self.compilation_config = vllm_config.compilation_config
347
348
349
350
351
352
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
353

354
        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
355
356

        set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
357

358
359
360
361
        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
362
        self.device = device
363
364
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
365

366
367
368
        self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            cache_config.cache_dtype, self.model_config
        )
369

370
        self.is_pooling_model = model_config.runner_type == "pooling"
371
        self.enable_prompt_embeds = model_config.enable_prompt_embeds
372
        self.is_multimodal_raw_input_only_model = (
373
374
            model_config.is_multimodal_raw_input_only_model
        )
375
376
        # This will be overridden in load_model()
        self.is_multimodal_pruning_enabled = False
377
        self.max_model_len = model_config.max_model_len
378
379
380

        # Always set to false after the first forward pass
        self.calculate_kv_scales = self.cache_config.calculate_kv_scales
381
        self.tp_size = self.parallel_config.tensor_parallel_size
382
        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
383
        self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
384
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
385
386
387
388
389
390
391
392
393
394
        #self.max_num_reqs = scheduler_config.max_num_seqs

        self.enable_lightly_cp = self.parallel_config.enable_lightly_cp
        self.enable_lightly_cplb = self.enable_lightly_cp and self.parallel_config.enable_lightly_cplb
        self.max_num_reqs = (
            scheduler_config.max_num_seqs
            if not self.enable_lightly_cplb
            else scheduler_config.max_num_seqs * 2
        )
        self.lightly_cp_threshould = envs.VLLM_LIGHTLY_CP_THRESHOULD
395

396
397
398
399
400
        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping mirco-batches
        # https://github.com/vllm-project/vllm/issues/18019
        self.broadcast_pp_output = (
401
            self.parallel_config.distributed_executor_backend == "external_launcher"
402
            and len(get_pp_group().ranks) > 1
403
        )
404

405
        # Model-related.
406
        self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
407
        self.inputs_embeds_size = model_config.get_inputs_embeds_size()
408
        self.attention_chunk_size = model_config.attention_chunk_size
409
        # Only relevant for models using ALiBi (e.g, MPT)
410
        self.use_alibi = model_config.uses_alibi
411

412
        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
413
        self.is_mm_prefix_lm = self.model_config.is_mm_prefix_lm
414

415
        # Multi-modal data support
416
        self.mm_registry = MULTIMODAL_REGISTRY
417
        self.uses_mrope = model_config.uses_mrope
guanyu1's avatar
guanyu1 committed
418
        self.use_1d_mrope = self.uses_mrope and envs.VLLM_1D_MROPE
419
        self.uses_xdrope_dim = model_config.uses_xdrope_dim
420
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
421
422
            model_config
        )
423

424
425
426
        if self.model_config.is_encoder_decoder:
            # Maximum length of the encoder input, only for encoder-decoder
            # models.
427
            self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens
428
429
430
        else:
            self.max_encoder_len = 0

431
432
433
        # Async scheduling
        self.use_async_scheduling = self.scheduler_config.async_scheduling

434
        # Sampler
435
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
436

437
        self.eplb_state: EplbState | None = None
438
439
440
441
442
443
        """
        State of the expert parallelism load balancer.

        Will be lazily initialized when the model is loaded.
        """

444
        # Lazy initializations
445
        # self.model: nn.Module  # Set after load_model
446
        # Initialize in initialize_kv_cache
447
        self.kv_caches: list[torch.Tensor] = []
448
449
450
        # Initialize in initialize_kv_cache_tensors
        self.cross_layers_kv_cache: torch.Tensor | None = None
        self.cross_layers_attn_backend: type[AttentionBackend] | None = None
451
452
        # indexes: [kv_cache_group_id][attn_group]
        self.attn_groups: list[list[AttentionGroup]] = []
453
454
        # self.kv_cache_config: KVCacheConfig

455
456
        # mm_hash ->  encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}
457

458
        self.use_aux_hidden_state_outputs = False
459
460
461
462
463
        # Set up speculative decoding.
        # NOTE(Jiayi): currently we put the entire draft model on
        # the last PP rank. This is not ideal if there are many
        # layers in the draft model.
        if self.speculative_config and get_pp_group().is_last_rank:
464
            self.drafter: (
465
466
467
468
469
                NgramProposer
                | SuffixDecodingProposer
                | EagleProposer
                | DraftModelProposer
                | MedusaProposer
470
            )
471
472
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
473
474
475
476
477
478
            elif self.speculative_config.uses_draft_model():
                self.drafter = DraftModelProposer(
                    vllm_config=self.vllm_config,
                    device=self.device,
                    runner=self,
                )
479
480
            elif self.speculative_config.method == "suffix":
                self.drafter = SuffixDecodingProposer(self.vllm_config)
481
            elif self.speculative_config.use_eagle():
482
                self.drafter = EagleProposer(self.vllm_config, self.device, self)
483
                if self.speculative_config.method == "eagle3":
484
485
486
                    self.use_aux_hidden_state_outputs = (
                        self.drafter.eagle3_use_aux_hidden_state
                    )
487
488
            elif self.speculative_config.method == "medusa":
                self.drafter = MedusaProposer(
489
                    vllm_config=self.vllm_config, device=self.device
490
                )
491
            else:
492
493
494
495
                raise ValueError(
                    "Unknown speculative decoding method: "
                    f"{self.speculative_config.method}"
                )
王敏's avatar
王敏 committed
496
497
498
499
500
            
            if not envs.VLLM_REJECT_SAMPLE_OPT:
                self.rejection_sampler = RejectionSampler(self.sampler)
            else:
                self.rejection_sampler = OptRejectionSampler(self.sampler)
501

502
503
504
        self.num_spec_tokens = 0
        if self.speculative_config:
            self.num_spec_tokens = self.speculative_config.num_speculative_tokens
505
506
507
508
509
            draft_config = self.speculative_config.draft_model_config
            if draft_config is not None and draft_config.max_model_len is not None:
                self.effective_drafter_max_model_len = draft_config.max_model_len
            else:
                self.effective_drafter_max_model_len = self.max_model_len
510

511
        # Request states.
512
        self.requests: dict[str, CachedRequestState] = {}
513
514
515
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}
516
        self.comm_stream = torch.cuda.Stream()
517

518
519
520
521
522
523
524
525
526
        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
527
528
529
530
        logits_processors = model_config.logits_processors
        custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
            tuple(logits_processors) if logits_processors is not None else ()
        )
531
532
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
533
534
535
            # We need to use the encoder length for encoder-decoer
            # because of KV cache for cross-attention.
            max_model_len=max(self.max_model_len, self.max_encoder_len),
536
537
538
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
539
            vocab_size=self.model_config.get_vocab_size(),
540
            block_sizes=[self.cache_config.block_size],
541
            kernel_block_sizes=[self.cache_config.block_size],
542
            is_spec_decode=bool(self.vllm_config.speculative_config),
543
            logitsprocs=build_logitsprocs(
544
545
546
                self.vllm_config,
                self.device,
                self.pin_memory,
547
                self.is_pooling_model,
548
                custom_logitsprocs,
549
            ),
550
551
552
            # We currently don't know whether a particular custom logits processor
            # uses output token ids so we set this conservatively.
            logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
553
            is_pooling_model=self.is_pooling_model,
554
            cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
555
        )
556

557
558
559
560
561
        # Separate cuda stream for overlapping transfer of sampled token ids from
        # GPU to CPU when async scheduling is enabled.
        self.async_output_copy_stream: torch.cuda.Stream | None = None
        # cuda event to synchronize use of reused CPU tensors between steps
        # when async scheduling is enabled.
562
        self.prepare_inputs_event: torch.Event | None = None
563
564
        if self.use_async_scheduling:
            self.async_output_copy_stream = torch.cuda.Stream()
565
            self.prepare_inputs_event = torch.Event()
566

567
        # self.cudagraph_batch_sizes sorts in ascending order.
568
569
570
571
        if (
            self.compilation_config.cudagraph_capture_sizes
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
572
573
            self.cudagraph_batch_sizes = sorted(
                self.compilation_config.cudagraph_capture_sizes
574
            )
575

576
        # Cache the device properties.
577
        self._init_device_properties()
578

579
580
581
582
        # Encoder timing registry for observability
        self.encoder_timing_registry: dict[str, EncoderTimingStats] = {}
        self._encoder_timing_lock = threading.Lock()

583
        # Persistent buffers for CUDA graphs.
584
585
586
587
588
        self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
        self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
        self.query_start_loc = self._make_buffer(
            self.max_num_reqs + 1, dtype=torch.int32
        )
589
        self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
590
        self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
591
592
593
594
        if self.dcp_world_size > 1:
            self.dcp_local_seq_lens = self._make_buffer(
                self.max_num_reqs, dtype=torch.int32
            )
595
596
597
        # Because inputs_embeds may be bfloat16 and we don't need a numpy
        # version of this tensor, avoid a RuntimeError by not creating a
        # numpy buffer.
598
        self.inputs_embeds = self._make_buffer(
599
            self.max_num_tokens, self.inputs_embeds_size, dtype=self.dtype, numpy=False
600
601
        )
        self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
602
603
        self.discard_request_mask = self._make_buffer(
            self.max_num_reqs, dtype=torch.bool
604
605
606
607
608
609
610
        )
        self.num_decode_draft_tokens = self._make_buffer(
            self.max_num_reqs, dtype=torch.int32
        )
        self.num_accepted_tokens = self._make_buffer(
            self.max_num_reqs, dtype=torch.int64
        )
611

612
613
        # Only relevant for multimodal models
        if self.supports_mm_inputs:
614
615
616
617
618
619
620
            # Double buffer to avoid race condition: previous iteration's async
            # copy may still be reading from CPU while current iteration writes.
            self.is_mm_embed_buffers = [
                self._make_buffer(self.max_num_tokens, dtype=torch.bool),
                self._make_buffer(self.max_num_tokens, dtype=torch.bool),
            ]
            self.is_mm_embed_idx = 0
621
622

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
623
        if self.uses_mrope:
Roger Wang's avatar
Roger Wang committed
624
625
626
627
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
628
629
630
631
632
633

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
guanyu1's avatar
guanyu1 committed
634
635
636
637
638
639
640
641
            if self.use_1d_mrope:
                self.mrope_positions = self._make_buffer(
                    3 * (self.max_num_tokens + 1), dtype=torch.int64
                )
            else:
                self.mrope_positions = self._make_buffer(
                    (3, self.max_num_tokens + 1), dtype=torch.int64
                )
642

643
644
645
646
647
648
        # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
        if self.uses_xdrope_dim > 0:
            # Similar to mrope but use assigned dimension number for RoPE, 4 as default.
            self.xdrope_positions = self._make_buffer(
                (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
            )
649

650
        # None in the first PP rank. The rest are set after load_model.
651
        self.intermediate_tensors: IntermediateTensors | None = None
652

653
        # OPTIMIZATION: Cache the tensors rather than creating them every step.
654
        # Keep in int64 to avoid overflow with long context
655
656
657
658
        self.arange_np = np.arange(
            max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
            dtype=np.int64,
        )
659

660
661
662
663
664
        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}
665
666
667
668
669
        self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()

        self.kv_sharing_fast_prefill_logits_indices = None
        if self.cache_config.kv_sharing_fast_prefill:
            self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
670
671
                self.max_num_tokens, dtype=torch.int32, device=self.device
            )
672

673
        self.uniform_decode_query_len = 1 + self.num_spec_tokens
674
675
676
677

        # Cudagraph dispatcher for runtime cudagraph dispatching.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

678
        self.mm_budget = (
679
            MultiModalBudget(self.vllm_config, self.mm_registry)
680
681
682
            if self.supports_mm_inputs
            else None
        )
683

684
        self.reorder_batch_threshold: int | None = None
685

686
687
688
689
690
        # Attention layers that are only in the KVCacheConfig of the runner
        # (e.g., KV sharing, encoder-only attention), but not in the
        # KVCacheConfig of the scheduler.
        self.runner_only_attn_layers: set[str] = set()

691
        # Cached outputs.
692
        self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
693
        self._draft_token_req_ids: list[str] | None = None
694
        self.transfer_event = torch.Event()
695
        self.sampled_token_ids_pinned_cpu = torch.empty(
696
            (self.max_num_reqs, 1),
697
698
            dtype=torch.int64,
            device="cpu",
699
700
            pin_memory=self.pin_memory,
        )
701

702
703
        # Pre-allocated tensor for copying valid sampled token counts to CPU,
        # with dedicated stream for overlapping and event for coordination.
704
        self.valid_sampled_token_count_event: torch.Event | None = None
705
        self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
        # We also copy the drafted tokens to the CPU asynchronously,
        # in case we need them for structured outputs.
        self.draft_token_ids_event: torch.Event | None = None
        self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None
        self.valid_sampled_token_count_cpu: torch.Tensor | None = None
        self.draft_token_ids_cpu: torch.Tensor | None = None
        if self.num_spec_tokens:
            self.draft_token_ids_event = torch.Event()
            self.draft_token_ids_copy_stream = torch.cuda.Stream()
            self.draft_token_ids_cpu = torch.empty(
                (self.max_num_reqs, self.num_spec_tokens),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory,
            )
            if self.use_async_scheduling:
                self.valid_sampled_token_count_event = torch.Event()
                self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
                self.valid_sampled_token_count_cpu = torch.empty(
                    self.max_num_reqs,
                    dtype=torch.int64,
                    device="cpu",
                    pin_memory=self.pin_memory,
                )
730

731
732
        # Ephemeral state transferred between execute_model() and sample_tokens().
        self.execute_model_state: ExecuteModelState | None = None
733
        self.kv_connector_output: KVConnectorOutput | None = None
734
        self.mamba_state_idx: dict[str, int] = {}
735
        self.layerwise_nvtx_hooks_registered = False
736

王敏's avatar
王敏 committed
737
738
        self.draft_probs : Optional[DraftProbs] = None

739
740
741
742
743
744
745
    def update_max_model_len(self, max_model_len: int) -> None:
        """Set a new maximum model length on the runner.

        When speculative decoding is configured and the draft model does not
        define its own ``max_model_len``, the drafter's effective limit is
        updated to follow the new target limit.

        Args:
            max_model_len: The new maximum sequence length.
        """
        self.max_model_len = max_model_len
        spec_cfg = self.speculative_config
        if not spec_cfg:
            return
        drafter_cfg = spec_cfg.draft_model_config
        drafter_has_own_limit = (
            drafter_cfg is not None and drafter_cfg.max_model_len is not None
        )
        if not drafter_has_own_limit:
            self.effective_drafter_max_model_len = self.max_model_len

746
747
748
749
    def reset_mm_cache(self) -> None:
        """Reset the multi-modal budget's cache; no-op when there is no budget."""
        budget = self.mm_budget
        if not budget:
            return
        budget.reset_cache()

750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
    @torch.inference_mode()
    def init_fp8_kv_scales(self) -> None:
        """
        Re-initialize the KV cache and FP8 KV scales after waking from sleep.

        Does nothing unless ``cache_config.cache_dtype`` starts with ``"fp8"``.

        1. Zero every non-None tensor in ``self.kv_caches`` to remove garbage
           data left over from re-allocation.
        2. Reset the K/V scaling factors (``_k_scale``/``k_scale`` and
           ``_v_scale``/``v_scale``) of every Attention / MLAAttention layer in
           the static forward context to 1.0. If these are left at 0.0 (the
           default after wake_up), all KV cache values become effectively
           zero, causing gibberish output.
        """
        if not self.cache_config.cache_dtype.startswith("fp8"):
            return

        for kv_cache in getattr(self, "kv_caches", []):
            if kv_cache is None:
                continue
            kv_cache.zero_()

        # TODO: Generally, scale is 1.0 if user uses on-the-fly fp8
        # kvcache quant. However, to get better accuracy, compression
        # frameworks like llm-compressors allow users to tune the
        # scale. We may need to restore the specific calibrated scales
        # here in the future.
        # Each entry: (candidate attribute names, value to fill). K scales are
        # processed before V scales, matching attribute-name order below.
        scale_resets = (
            (("_k_scale", "k_scale"), 1.0),
            (("_v_scale", "v_scale"), 1.0),
        )

        forward_ctx = self.compilation_config.static_forward_context
        for layer in forward_ctx.values():
            if not isinstance(layer, (Attention, MLAAttention)):
                continue
            for attr_names, fill_value in scale_resets:
                for attr_name in attr_names:
                    scale = getattr(layer, attr_name, None)
                    # Only tensor-valued scales are reset in place.
                    if isinstance(scale, torch.Tensor):
                        scale.fill_(fill_value)

794
795
796
    def _get_positions(self, num_tokens: Any):
        if isinstance(num_tokens, int):
            if self.uses_mrope:
guanyu1's avatar
guanyu1 committed
797
798
799
800
                if self.use_1d_mrope:
                    return self.mrope_positions.gpu[: 3 * num_tokens].view(
                        num_tokens, 3
                    ).T
801
                return self.mrope_positions.gpu[:, :num_tokens]
802
803
            if self.uses_xdrope_dim > 0:
                return self.xdrope_positions.gpu[:, :num_tokens]
804
805
806
            return self.positions.gpu[:num_tokens]
        else:
            if self.uses_mrope:
guanyu1's avatar
guanyu1 committed
807
808
                if self.use_1d_mrope:
                    return self.mrope_positions.gpu.view(-1, 3)[num_tokens].T
809
                return self.mrope_positions.gpu[:, num_tokens]
810
811
            if self.uses_xdrope_dim > 0:
                return self.xdrope_positions.gpu[:, num_tokens]
812
813
            return self.positions.gpu[num_tokens]

guanyu1's avatar
guanyu1 committed
814

815
    def _make_buffer(
        self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True
    ) -> CpuGpuBuffer:
        """Create a ``CpuGpuBuffer`` bound to this runner's device.

        Args:
            *size: Shape of the buffer.
            dtype: Element dtype.
            numpy: Forwarded to ``CpuGpuBuffer`` as ``with_numpy``.

        Returns:
            A new ``CpuGpuBuffer`` using ``self.device`` and ``self.pin_memory``.
        """
        buffer_options = dict(
            dtype=dtype,
            device=self.device,
            pin_memory=self.pin_memory,
            with_numpy=numpy,
        )
        return CpuGpuBuffer(*size, **buffer_options)
guanyu1's avatar
guanyu1 committed
825

guanyu1's avatar
guanyu1 committed
826
827
828
    def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None:
        if not self.uses_mrope:
            return
guanyu1's avatar
guanyu1 committed
829
830
831
832
833
834
835
        if self.use_1d_mrope:
            num_values = 3 * num_tokens
            self.mrope_positions.gpu[:num_values].copy_(
                self.mrope_positions.cpu[:num_values],
                non_blocking=True,
            )
            return
836
837
838
839
        # self.mrope_positions.gpu[:, :num_tokens].copy_(
        #     self.mrope_positions.cpu[:, :num_tokens],
        #     non_blocking=True,
        # )
guanyu1's avatar
guanyu1 committed
840
        self.mrope_positions.gpu[:, :num_tokens].copy_(
841
            self.mrope_positions.cpu[:, :num_tokens].contiguous().pin_memory(),
guanyu1's avatar
guanyu1 committed
842
843
844
845
846
847
            non_blocking=True,
        )

    def _copy_xdrope_positions_to_gpu(self, num_tokens: int) -> None:
        if self.uses_xdrope_dim <= 0:
            return
guanyu1's avatar
guanyu1 committed
848

guanyu1's avatar
guanyu1 committed
849
850
851
852
853
        self.xdrope_positions.gpu[:, :num_tokens].copy_(
            self.xdrope_positions.cpu[:, :num_tokens],
            non_blocking=True,
        )

854

855
    def _init_model_kwargs(self):
856
857
        model_kwargs = dict[str, Any]()

858
        if not self.is_pooling_model:
859
860
            return model_kwargs

861
862
        num_reqs = self.input_batch.num_reqs
        pooling_params = self.input_batch.get_pooling_params()
863
864
865

        token_type_id_requests = dict[int, Any]()
        for i, param in enumerate(pooling_params):
866
867
868
869
870
            if (
                param.extra_kwargs is not None
                and (token_types := param.extra_kwargs.get("compressed_token_type_ids"))
                is not None
            ):
871
872
873
874
875
                token_type_id_requests[i] = token_types

        if len(token_type_id_requests) == 0:
            return model_kwargs

876
        seq_lens = self.seq_lens.gpu[:num_reqs]
877
878
879
880
881
882
883
884
        token_type_ids = []

        for i in range(num_reqs):
            pos = token_type_id_requests.get(i, seq_lens[i])
            ids = (torch.arange(seq_lens[i]) >= pos).int()
            token_type_ids.append(ids)

        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
885
886
            device=self.device
        )
887
        return model_kwargs
888

889
    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
890
891
        """
        Update the order of requests in the batch based on the attention
892
        backend's needs. For example, some attention backends (namely MLA) may
893
894
895
896
897
898
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.
        """
899
900
901
902
903
904
905
906
        # Attention free models have zero kv_cache_goups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping its internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

907
908
909
910
        if self.reorder_batch_threshold is not None:
            reorder_batch_to_split_decodes_and_prefills(
                self.input_batch,
                scheduler_output,
911
912
                decode_threshold=self.reorder_batch_threshold,
            )
913

914
915
    # Note: used for model runner override.
    def _init_device_properties(self) -> None:
        """Query ``torch.cuda.get_device_properties`` for ``self.device`` and
        cache the result plus its streaming-multiprocessor count."""
        props = torch.cuda.get_device_properties(self.device)
        self.device_properties = props
        self.num_sms = props.multi_processor_count

    # Note: used for model runner override.
    def _sync_device(self) -> None:
        """Block the host until all queued CUDA work on the current device
        has completed (``torch.cuda.synchronize`` with its default device)."""
        torch.cuda.synchronize()

924
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Args:
            scheduler_output: The scheduler output for the current step.
        """
        # Fast path for idle steps under synchronous scheduling. It must not
        # be taken while the scheduler output still carries finished requests
        # or encoder-cache entries to free: skipping the cleanup below would
        # leave stale entries in `self.requests`, the persistent batch, the
        # encoder cache and the draft-probs store (and a re-submitted request
        # with the same ID would then be treated as a streaming update in the
        # new-request loop). With cleanup pending we run the full update,
        # exactly as the async-scheduling path already does for zero-token
        # steps.
        if (
            scheduler_output.total_num_scheduled_tokens == 0
            and not self.use_async_scheduling
            and not scheduler_output.finished_req_ids
            and not scheduler_output.free_encoder_mm_hashes
        ):
            return

        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.num_prompt_logprobs.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)

        # Prune draft probs of finished requests.
        if (
            envs.VLLM_REJECT_SAMPLE_OPT
            and self.draft_probs is not None
            and len(scheduler_output.finished_req_ids) > 0
        ):
            self.draft_probs.prune(list(scheduler_output.finished_req_ids))

        # Free the cached encoder outputs.
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
        # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint,
        # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds
        # apart from the forced-preemption case in reset_prefix_cache. And in
        # that case we include the resumed_req_ids in the unscheduled set so
        # that they get cleared from the persistent batch before being re-scheduled
        # in the normal resumed request path.
        unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids)
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        reqs_to_add: list[CachedRequestState] = []

        # Track re-added requests on non-last ranks that need token_ids_cpu
        # fix-up after add_request. On non-last ranks, output_token_ids
        # does NOT include accepted draft tokens, so add_request() places
        # tokens at wrong positions. We save (new_token_ids, num_computed)
        # here and fix up token_ids_cpu right after add_request.
        fix_tokens_map: dict[str, tuple[list[int], int]] = {}

        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            if req_id in self.requests:
                # For streaming case only.
                req_state = self._update_streaming_request(req_id, new_req_data)
                reqs_to_add.append(req_state)
                continue

            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            if (
                sampling_params
                and sampling_params.sampling_type == SamplingType.RANDOM_SEED
            ):
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            if self.is_pooling_model:
                assert pooling_params is not None
                task = pooling_params.task
                assert task is not None, "You did not set `task` in the API"

                model = cast(VllmModelForPooling, self.get_model())
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt_embeds=new_req_data.prompt_embeds,
                mm_features=new_req_data.mm_features,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            if sampling_params and sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[req_id] = (
                    self.input_batch.vocab_size
                    if sampling_params.prompt_logprobs == -1
                    else sampling_params.prompt_logprobs
                )

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                self._init_mrope_positions(req_state)

            # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
            if self.uses_xdrope_dim > 0:
                self._init_xdrope_positions(req_state)

            reqs_to_add.append(req_state)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

        # Wait until valid_sampled_tokens_count is copied to cpu,
        # then use it to update actual num_computed_tokens of each request.
        valid_sampled_token_count = self._get_valid_sampled_token_count()

        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_id in req_data.resumed_req_ids
            num_output_tokens = req_data.num_output_tokens[i]
            req_index = self.input_batch.req_id_to_index.get(req_id)

            if req_state.prev_num_draft_len and self.use_async_scheduling:
                # prev_num_draft_len is used in async scheduling mode with
                # spec decode. It indicates whether num_computed_tokens of the
                # request needs to be updated. For example:
                # first step: num_computed_tokens = 0, spec_tokens = [],
                # prev_num_draft_len = 0.
                # second step: num_computed_tokens = 100 (prompt length),
                # spec_tokens = [a,b], prev_num_draft_len = 0.
                # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
                # prev_num_draft_len = 2.
                # num_computed_tokens in the first and second steps doesn't
                # contain the spec tokens length, but in the third step it
                # does. We only need to update num_computed_tokens when
                # prev_num_draft_len > 0.
                if req_index is None:
                    req_state.prev_num_draft_len = 0
                else:
                    assert self.input_batch.prev_req_id_to_index is not None
                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
                    num_accepted = valid_sampled_token_count[prev_req_index] - 1
                    num_rejected = req_state.prev_num_draft_len - num_accepted
                    num_computed_tokens -= num_rejected
                    req_state.output_token_ids.extend([-1] * num_accepted)

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = (
                    num_computed_tokens + len(new_token_ids) - req_state.num_tokens
                )
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
            elif num_output_tokens < len(req_state.output_token_ids):
                # Some output tokens were discarded due to a sync-KV-load
                # failure. Align the cached state.
                del req_state.output_token_ids[num_output_tokens:]
                if req_index is not None:
                    end_idx = (
                        self.input_batch.num_prompt_tokens[req_index]
                        + num_output_tokens
                    )
                    self.input_batch.num_tokens_no_spec[req_index] = end_idx

            # Update the block IDs.
            if not resumed_from_preemption:
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
                        block_ids.extend(new_ids)
            else:
                assert req_index is None
                assert new_block_ids is not None
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.

                if self.use_async_scheduling and num_output_tokens > 0:
                    # We must recover the output token ids for resumed requests in the
                    # async scheduling case, so that correct input_ids are obtained.
                    resumed_token_ids = req_data.all_token_ids[req_id]
                    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]

                # On non-last ranks with PP + spec decode, output_token_ids
                # doesn't include accepted draft tokens. Save the fix-up
                # data so we can correct token_ids_cpu after add_request.
                # (`new_token_ids` is only bound on non-last ranks; the
                # short-circuit keeps it from being read otherwise.)
                if not is_last_rank and new_token_ids:
                    fix_tokens_map[req_id] = (list(new_token_ids), num_computed_tokens)

                reqs_to_add.append(req_state)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
            if new_block_ids is not None:
                self.input_batch.block_table.append_row(new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_token_index:end_token_index
                ] = new_token_ids
                self.input_batch.num_tokens_no_spec[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for request in reqs_to_add:
            self.input_batch.add_request(request)
            req_id = request.req_id
            req_index = self.input_batch.req_id_to_index[req_id]

            # Fix token_ids_cpu for re-added requests on non-last PP ranks.
            # add_request() copies output_token_ids to token_ids_cpu, but on
            # non-last ranks output_token_ids does NOT include accepted draft
            # tokens, causing tokens to land at wrong positions. Overwrite
            # the new tokens at the correct position (num_computed_tokens)
            # and adjust num_tokens_no_spec before placing spec tokens.
            fix_data = fix_tokens_map.get(req_id)
            if fix_data is not None:
                new_toks, n_computed = fix_data
                start = n_computed
                end = start + len(new_toks)
                self.input_batch.token_ids_cpu[req_index, start:end] = new_toks
                self.input_batch.num_tokens_no_spec[req_index] = end

            # Place spec tokens at the (now-correct) num_tokens_no_spec offset.
            self.input_batch.update_req_spec_token_ids(
                request, scheduled_spec_tokens)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow the attention backend to reorder the batch (see
        # `_may_reorder_batch`).
        self._may_reorder_batch(scheduler_output)

        # Refresh batch metadata with any pending updates. With the reject
        # sampling optimization, each request's metadata row is repeated once
        # per scheduled token (1 + number of draft tokens), indexed by the
        # request's post-condense/reorder batch position.
        repeat_counts = None
        if envs.VLLM_REJECT_SAMPLE_OPT and \
                scheduler_output.scheduled_spec_decode_tokens:
            repeat_counts = [1] * self.input_batch.num_reqs
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index.get(req_id)
                if req_idx is not None:
                    repeat_counts[req_idx] += len(draft_token_ids)
            repeat_counts = torch.tensor(repeat_counts, dtype=torch.int32, device="cpu")
        self.input_batch.refresh_metadata(repeat_counts)
1214

1215
    def _update_states_after_model_execute(
1216
        self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
1217
    ) -> None:
1218
1219
1220
1221
1222
1223
1224
1225
        """Update the cached states after model execution.

        This is used for MTP/EAGLE for hybrid models, as in linear attention,
        only the last token's state is kept. In MTP/EAGLE, for draft tokens
        the state are kept util we decide how many tokens are accepted for
        each sequence, and a shifting is done during the next iteration
        based on the number of accepted tokens.
        """
1226
        if not self.speculative_config or not self.model_config.is_hybrid:
1227
1228
1229
            return

        # Find the number of accepted tokens for each sequence.
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
        num_accepted_tokens = (
            (
                torch.cat(
                    [
                        output_token_ids,
                        torch.full(
                            (output_token_ids.size(0), 1),
                            -1,
                            device=output_token_ids.device,
                        ),
                    ],
                    dim=1,
                )
                == -1
            )
            .int()
            .argmax(-1)
            .cpu()
            .numpy()
        )
1250
1251
        for i, num_tokens in enumerate(num_accepted_tokens):
            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
        if self.cache_config.mamba_cache_mode == "align":
            mamba_utils.postprocess_mamba(
                scheduler_output,
                self.kv_cache_config,
                self.input_batch,
                self.requests,
                self.mamba_state_idx,
                self.compilation_config.static_forward_context,
                self.model.get_mamba_state_copy_func(),
            )
1262

1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
    def _update_streaming_request(
        self, req_id: str, new_req_data: NewRequestData
    ) -> CachedRequestState:
        """Updates streaming session request from `scheduled_new_reqs`.

        Removes the request from InputBatch (if present), updates the cached
        state, and prepares it for re-addition to the batch.

        NOTE: prompt_token_ids includes intermediate output tokens - tokens
        previously generated but now are input context (part of the prompt).

        Because the prompt is replaced, the rotary position state derived from
        it (M-RoPE and XD-RoPE) is recomputed, matching how a brand-new
        request is initialized in `_update_states`.

        Args:
            req_id: ID of the existing streaming request.
            new_req_data: The new request data from the scheduler.

        Returns:
            The updated cached request state; the caller re-adds it to the
            persistent batch.
        """
        self.input_batch.remove_request(req_id)
        req_state = self.requests[req_id]

        req_state.prompt_token_ids = new_req_data.prompt_token_ids
        req_state.mm_features = new_req_data.mm_features
        req_state.prompt_embeds = new_req_data.prompt_embeds
        req_state.sampling_params = new_req_data.sampling_params
        req_state.pooling_params = new_req_data.pooling_params
        req_state.block_ids = new_req_data.block_ids
        req_state.num_computed_tokens = new_req_data.num_computed_tokens
        req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )

        # Clear `output_token_ids` as previous output tokens are now part of
        # `prompt_token_ids`.
        req_state.output_token_ids.clear()

        if self.uses_mrope:
            self._init_mrope_positions(req_state)

        # XD-RoPE positions are also derived from prompt_token_ids and
        # mm_features; without recomputing them here they would still
        # describe the previous prompt.
        if self.uses_xdrope_dim > 0:
            self._init_xdrope_positions(req_state)

        return req_state
1296

1297
    def _init_mrope_positions(self, req_state: CachedRequestState):
        """Compute M-RoPE positions for a request and store them on it.

        Sets ``req_state.mrope_positions`` and
        ``req_state.mrope_position_delta`` from the model's
        ``get_mrope_input_positions(prompt_token_ids, mm_features)``.

        Asserts that the model supports M-RoPE and that the request has
        ``prompt_token_ids``.
        """
        model = self.get_model()
        assert supports_mrope(model), "M-RoPE support is not implemented."
        assert req_state.prompt_token_ids is not None, (
            "M-RoPE requires prompt_token_ids to be available."
        )

        rope_model = cast(SupportsMRoPE, model)
        positions, position_delta = rope_model.get_mrope_input_positions(
            req_state.prompt_token_ids,
            req_state.mm_features,
        )
        req_state.mrope_positions = positions
        req_state.mrope_position_delta = position_delta
1311

1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
    def _init_xdrope_positions(self, req_state: CachedRequestState):
        """Compute XD-RoPE positions for a request and store them in
        ``req_state.xdrope_positions``.

        Asserts that the request has ``prompt_token_ids`` and that the model
        supports XD-RoPE (checked in that order).
        """
        model = self.get_model()
        assert req_state.prompt_token_ids is not None, (
            "XD-RoPE requires prompt_token_ids to be available."
        )
        assert supports_xdrope(model), "XD-RoPE support is not implemented."

        rope_model = cast(SupportsXDRoPE, model)
        req_state.xdrope_positions = rope_model.get_xdrope_input_positions(
            req_state.prompt_token_ids,
            req_state.mm_features,
        )
1324

1325
    def _extract_mm_kwargs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> BatchedTensorInputs:
        """Collect raw multimodal kwargs of newly scheduled requests.

        Returns an empty dict when ``scheduler_output`` is falsy or the model
        is not a raw-multimodal-input-only model. Otherwise the non-``None``
        ``data`` of every multimodal feature of each new request is grouped
        by modality (placed on ``self.device``) and merged into one dict.
        Because groups are merged with ``dict.update``, a key appearing in a
        later group overwrites the same key from an earlier one.
        """
        if not scheduler_output or not self.is_multimodal_raw_input_only_model:
            return {}

        mm_kwargs = list[MultiModalKwargsItem]()
        for req in scheduler_output.scheduled_new_reqs:
            for feature in req.mm_features:
                if feature.data is not None:
                    mm_kwargs.append(feature.data)

        # Input all modalities at once
        mm_kwargs_combined: BatchedTensorInputs = {}
        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
        ):
            mm_kwargs_combined.update(mm_kwargs_group)

        return mm_kwargs_combined

    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
        """Build dummy raw multimodal kwargs for ``num_seqs`` sequences.

        Returns an empty dict unless the model consumes raw multimodal
        inputs. Otherwise the dummy batch uses the modality reported by
        ``mm_budget.get_modality_with_max_tokens()``.

        Raises:
            AssertionError: if ``self.mm_budget`` is unset for a raw-input
                multimodal model.
        """
        if not self.is_multimodal_raw_input_only_model:
            return {}

        mm_budget = self.mm_budget
        assert mm_budget is not None

        dummy_modality = mm_budget.get_modality_with_max_tokens()
        return self._get_mm_dummy_batch(dummy_modality, num_seqs)

    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
1362
        cumsum_dtype: np.dtype | None = None,
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the cumulative sum and batched arange of the given array.
        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_tokens])
        """
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
        total_num_tokens = cu_num_tokens[-1]
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_tokens] - cumsums_offsets

        return cu_num_tokens, arange

1379
    def _prepare_input_ids(
1380
1381
1382
1383
        self,
        scheduler_output: "SchedulerOutput",
        total_num_scheduled_tokens: int,
        cu_num_tokens: np.ndarray,
1384
    ) -> None:
1385
        """Prepare the input IDs for the current batch.
1386

1387
1388
1389
1390
1391
1392
1393
        Carefully handles the `prev_sampled_token_ids` which can be cached
        from the previous engine iteration, in which case those tokens on the
        GPU need to be copied into the corresponding slots into input_ids."""

        if self.input_batch.prev_sampled_token_ids is None:
            # Normal scheduling case
            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
1394
1395
1396
            if self.enable_prompt_embeds:
                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
1397
1398
1399
1400
1401
1402
1403
            return

        # Async scheduling case, where some decode requests from the previous
        # iteration won't have entries in input_ids_cpu and need to be copied
        # on the GPU from prev_sampled_token_ids.
        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
        assert prev_req_id_to_index is not None
1404
1405
1406
1407
        sample_flattened_indices: list[int] = []
        spec_flattened_indices: list[int] = []
        prev_common_req_indices: list[int] = []
        prev_draft_token_indices: list[int] = []
1408
1409
        indices_match = True
        max_flattened_index = -1
1410
1411
1412
        total_num_spec_tokens = 0
        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

1413
1414
1415
1416
1417
        for req_id, cur_index in self.input_batch.req_id_to_index.items():
            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
                prev_common_req_indices.append(prev_index)
                # We need to compute the flattened input_ids index of the
                # last token in each common request.
1418
1419
                draft_len = len(scheduled_spec_tokens.get(req_id, ()))
                total_num_spec_tokens += draft_len
1420
                flattened_index = cu_num_tokens[cur_index].item() - 1
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
                # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
                # sample_flattened_indices = [0, 2, 5]
                # spec_flattened_indices = [1,   3, 4,    6, 7]
                sample_flattened_indices.append(flattened_index - draft_len)
                spec_flattened_indices.extend(
                    range(flattened_index - draft_len + 1, flattened_index + 1)
                )
                start = prev_index * self.num_spec_tokens
                # prev_draft_token_indices is used to find which draft_tokens_id
                # should be copied to input_ids
                # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
                # flatten draft_tokens_id [1,2,3,4,5,6]
                # draft_len of each request [1, 2, 1]
                # then prev_draft_token_indices is [0,   2, 3,   4]
                prev_draft_token_indices.extend(range(start, start + draft_len))
1436
                indices_match &= prev_index == flattened_index
1437
                max_flattened_index = max(max_flattened_index, flattened_index)
1438
1439
1440
        num_commmon_tokens = len(sample_flattened_indices)
        total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
        if num_commmon_tokens < total_without_spec:
1441
1442
1443
            # If not all requests are decodes from the last iteration,
            # We need to copy the input_ids_cpu to the GPU first.
            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
1444
1445
1446
            if self.enable_prompt_embeds:
                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
1447
1448
        if num_commmon_tokens == 0:
            # No requests in common with the previous iteration
1449
            # So input_ids.cpu will have all the input ids.
1450
1451
1452
1453
1454
1455
1456
            return
        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
            # Common-case optimization: the batch is unchanged
            # and no reordering happened.
            # The indices are both the same permutation of 0..N-1 so
            # we can copy directly using a single slice.
            self.input_ids.gpu[:num_commmon_tokens].copy_(
1457
1458
1459
                self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0],
                non_blocking=True,
            )
1460
1461
            if self.enable_prompt_embeds:
                self.is_token_ids.gpu[:num_commmon_tokens] = True
1462
            return
1463
        # Upload the index tensors asynchronously so the scatter can be non-blocking.
1464
1465
        sampled_tokens_index_tensor = torch.tensor(
            sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
1466
        ).to(self.device, non_blocking=True)
1467
        prev_common_req_indices_tensor = torch.tensor(
1468
1469
            prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)
1470
1471
        self.input_ids.gpu.scatter_(
            dim=0,
1472
            index=sampled_tokens_index_tensor,
1473
            src=self.input_batch.prev_sampled_token_ids[
1474
1475
1476
                prev_common_req_indices_tensor, 0
            ],
        )
1477

1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
        # Scatter the draft tokens after the sampled tokens are scattered.
        if self._draft_token_ids is None or not spec_flattened_indices:
            return

        assert isinstance(self._draft_token_ids, torch.Tensor)
        draft_tokens_index_tensor = torch.tensor(
            spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)
        prev_draft_token_indices_tensor = torch.tensor(
            prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)

        # because input_ids dtype is torch.int32,
        # so convert draft_token_ids to torch.int32 here.
        draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)

        self.input_ids.gpu.scatter_(
            dim=0,
            index=draft_tokens_index_tensor,
            src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
        )
1499

1500
1501
    def _get_encoder_seq_lens(
        self,
        num_scheduled_tokens: dict[str, int],
        kv_cache_spec: KVCacheSpec,
        num_reqs: int,
        for_cudagraph_capture: bool = False,
    ) -> tuple[torch.Tensor | None, np.ndarray | None]:
        """Build per-request encoder sequence lengths for cross-attention.

        Args:
            num_scheduled_tokens: Scheduled requests; only the keys (request
                IDs) are used here.
            kv_cache_spec: KV cache spec of the group being built; anything
                other than a ``CrossAttentionSpec`` short-circuits.
            num_reqs: Number of request slots to fill (may include padding).
            for_cudagraph_capture: If True, every slot is overwritten with the
                model's ``max_source_positions`` (falling back to
                ``self.max_encoder_len``).

        Returns:
            ``(gpu_lengths, cpu_lengths)`` views of the first ``num_reqs``
            entries of ``self.encoder_seq_lens``, or ``(None, None)`` when
            ``kv_cache_spec`` is not a cross-attention spec.
        """
        if not isinstance(kv_cache_spec, CrossAttentionSpec):
            return None, None

        # Zero out buffer for padding requests that are not actually scheduled (CGs)
        self.encoder_seq_lens.np[:num_reqs] = 0

        # Build encoder_seq_lens array mapping request indices to
        # encoder lengths for inputs scheduled in this batch
        for req_id in num_scheduled_tokens:
            req_index = self.input_batch.req_id_to_index[req_id]
            req_state = self.requests[req_id]
            if req_state.mm_features is None:
                self.encoder_seq_lens.np[req_index] = 0
                continue

            # Get the total number of encoder input tokens for running encoder requests
            # whether encoding is finished or not so that cross-attention knows how
            # many encoder tokens to attend to.
            encoder_input_tokens = sum(
                feature.mm_position.length for feature in req_state.mm_features
            )
            self.encoder_seq_lens.np[req_index] = encoder_input_tokens

        if for_cudagraph_capture:
            # During CUDA graph capture, we need to use realistic encoder lengths
            # so that max_seqlen_k is captured with the correct value.
            max_encoder_len = getattr(
                self.model_config.hf_config,
                "max_source_positions",
                self.max_encoder_len,
            )
            self.encoder_seq_lens.np[:num_reqs] = max_encoder_len

        self.encoder_seq_lens.copy_to_gpu(num_reqs)
        encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
        encoder_seq_lens_cpu = self.encoder_seq_lens.np[:num_reqs]

        return encoder_seq_lens, encoder_seq_lens_cpu

    def _distribute_tokens_to_cp_ranks(
        self,
        total_q_len: int,
        q_lens_cpu: np.ndarray,
        kv_lens_cpu: np.ndarray,
        tp_rank: int,
        tp_size: int,
        req_ids: list[str],
    ):
        q_lens = []
        seq_count = 0
        seq_indexes = []
        kv_lens = []
        local_req_ids = []

        local_scatter_indexes_tensor = None
        gather_indexes_tensor = None

        if self.enable_lightly_cplb:
            rank_tokens = 0
            rank_pad_tokens = 0
            accu_q_start = 0
            scatter_indexes: list[int] = []
            num_requests = len(q_lens_cpu)
            for i in range(num_requests):
                req_q_len = q_lens_cpu[i]
                req_pad_q_len = round_up(q_lens_cpu[i], 2 * tp_size)
                kv_len = kv_lens_cpu[i]

                chunk_q_len = req_pad_q_len // (2 * tp_size)

                q_1_start = tp_rank * chunk_q_len
                q_1_end = (tp_rank + 1) * chunk_q_len
                q_2_start = req_pad_q_len - (tp_rank + 1) * chunk_q_len
                q_2_end = req_pad_q_len - tp_rank * chunk_q_len

                q_len_1 = (
                    chunk_q_len
                    if q_1_end <= req_q_len
                    else max(0, req_q_len - q_1_start)
                )
                q_len_2 = (
                    chunk_q_len
                    if q_2_end <= req_q_len
                    else max(0, req_q_len - q_2_start)
                )

                kv_len_1 = kv_len - req_q_len + min(req_q_len, q_1_end)
                kv_len_2 = kv_len - req_q_len + min(req_q_len, q_2_end)

                scatter_index1 = range(
                    accu_q_start + q_1_start, accu_q_start + q_1_start + q_len_1
                )

                scatter_index2 = range(
                    accu_q_start + q_2_start, accu_q_start + q_2_start + q_len_2
                )
                accu_q_start += req_q_len

                if q_len_1 > 0:
                    q_lens.append(q_len_1)
                    kv_lens.append(kv_len_1)
                    seq_indexes.append(i)
                    local_req_ids.append(req_ids[i])
                    scatter_indexes.extend(scatter_index1)
                    seq_count += 1
                    rank_tokens += q_len_1

                if q_len_2 > 0:
                    q_lens.append(q_len_2)
                    kv_lens.append(kv_len_2)
                    seq_indexes.append(i)
                    local_req_ids.append(req_ids[i])
                    scatter_indexes.extend(scatter_index2)
                    seq_count += 1
                    rank_tokens += q_len_2

                rank_pad_tokens += chunk_q_len * 2

            if len(scatter_indexes) < rank_pad_tokens:
                scatter_indexes.extend([-1] * (rank_pad_tokens - len(scatter_indexes)))

            local_scatter_indexes_tensor = torch.tensor(
                scatter_indexes, dtype=torch.int64, device=self.device
            )
            global_scatter_indexes_tensor = tensor_model_parallel_all_gather(
                local_scatter_indexes_tensor.contiguous(), dim=0
            )
            non_neg_mask = global_scatter_indexes_tensor != -1
            non_neg_values = global_scatter_indexes_tensor[non_neg_mask]
            non_neg_positions = torch.where(non_neg_mask)[0]
            sorted_indices = torch.argsort(non_neg_values)
            gather_indexes_tensor = non_neg_positions[sorted_indices]
            if isinstance(rank_tokens, torch.Tensor):
                rank_tokens = rank_tokens.item()
        else:
            tokens_per_rank = (total_q_len + tp_size - 1) // tp_size
            start_token = tp_rank * tokens_per_rank
            end_token = min((tp_rank + 1) * tokens_per_rank, total_q_len)
            
            current_seq = 0
            current_pos = 0
            rank_tokens = min(tokens_per_rank, end_token - start_token)
            while start_token < end_token and current_seq < len(q_lens_cpu):
                q_len = q_lens_cpu[current_seq]
                q_start = current_pos
                q_end = current_pos + q_len
                kv_len = kv_lens_cpu[current_seq]

                # Find overlap between this sequence and rank's token range
                overlap_start = max(start_token, q_start)
                overlap_end = min(end_token, q_end)

                if overlap_start < overlap_end:
                    # This sequence contributes tokens to this rank
                    token_count = overlap_end - overlap_start
                    q_lens.append(token_count)
                    start_token = overlap_end
                    seq_count += 1
                    seq_indexes.append(current_seq)
                    local_req_ids.append(req_ids[current_seq])

                    if q_end <= end_token:
                        kv_lens.append(kv_len)
                    else:
                        kv_lens.append(kv_len - (q_end - end_token))

                current_pos = q_end
                current_seq += 1

        return (
            rank_tokens,
            np.array(q_lens, dtype=np.int32),
            seq_count,
            np.array(kv_lens, dtype=np.int32),
            np.array(local_req_ids, dtype=str),
            local_scatter_indexes_tensor,
            gather_indexes_tensor,
            seq_indexes,
        )

    def _prepare_cp_metadata(
        self,
        num_reqs_padded,
        max_query_len,
        max_seq_len,
        num_tokens,
        block_table_gid_0,
        slot_mapping_gid_0,
        num_computed_tokens_cpu,
    ):
        """Build this rank's attention metadata for lightly context parallelism.

        A snapshot of the full-batch metadata is cloned into a
        ``CpCommonAttentionMetadata``. The batch is then split with
        ``_distribute_tokens_to_cp_ranks``, and the shared
        ``self.query_start_loc`` / ``self.seq_lens`` buffers are OVERWRITTEN
        with this rank's local sequences. A ``CommonAttentionMetadata``
        describing only those local sequences is returned; it carries the
        full-batch snapshot and the scatter/gather index tensors.
        """
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        tp_rank = get_tensor_model_parallel_rank()

        # Clone: the buffers below are rewritten with rank-local values.
        cp_common_metadata = CpCommonAttentionMetadata(
            query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1].clone(),
            query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1].clone(),
            seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(),
            _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(),
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            num_reqs=num_reqs_padded,
            req_ids=self.input_batch.req_ids,
            num_actual_tokens=num_tokens,
            num_kv_actual_tokens=num_tokens,
            block_table_tensor=block_table_gid_0,
            slot_mapping=slot_mapping_gid_0,
            _num_computed_tokens_cpu=num_computed_tokens_cpu
        )

        query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
        q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        kv_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
        total_q_len = num_tokens
        total_kv_len = num_tokens

        (
            total_q_len,
            q_lens_cpu,
            seq_count,
            kv_lens_cpu,
            local_req_ids,
            scatter_indexes_tensor,
            gather_indexes_tensor,
            seq_indexes_list,
        ) = self._distribute_tokens_to_cp_ranks(
            total_q_len,
            q_lens_cpu,
            kv_lens_cpu,
            tp_rank,
            tp_size,
            self.input_batch.req_ids,
        )

        num_reqs = seq_count

        cu_num_tokens = np.cumsum(q_lens_cpu)
        # A rank may receive no sequences; cu_num_tokens is then empty.
        last_cu_num_tokens = cu_num_tokens[-1] if num_reqs > 0 else 0
        self.query_start_loc.np[0] = 0
        self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
        self.query_start_loc.np[num_reqs + 1 :].fill(last_cu_num_tokens)
        self.query_start_loc.copy_to_gpu()
        q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
        q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]
        # Per-sequence query lengths; q_acc_lens_cpu holds cumulative offsets.
        local_q_lens_cpu = q_acc_lens_cpu[1:] - q_acc_lens_cpu[:-1]
        max_q_len = int(local_q_lens_cpu.max()) if num_reqs > 0 else 0

        self.seq_lens.np[:num_reqs] = kv_lens_cpu
        self.seq_lens.np[num_reqs:].fill(0)
        self.seq_lens.copy_to_gpu()
        kv_lens = self.seq_lens.gpu[:num_reqs]
        kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
        max_kv_len = int(kv_lens_cpu.max()) if num_reqs > 0 else 0

        # computed = seq_len - query_len for each local sequence.
        num_computed_tokens_cpu = kv_lens_cpu - local_q_lens_cpu
        blk_table_tensor = block_table_gid_0[seq_indexes_list]

        cm_base = CommonAttentionMetadata(
            query_start_loc=q_acc_lens,
            query_start_loc_cpu=q_acc_lens_cpu,
            seq_lens=kv_lens,
            _seq_lens_cpu=kv_lens_cpu,
            _num_computed_tokens_cpu=num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=total_q_len,
            max_query_len=max_q_len,
            max_seq_len=max_kv_len,
            block_table_tensor=blk_table_tensor,
            slot_mapping=slot_mapping_gid_0,
            causal=True,
            num_kv_actual_tokens=total_kv_len,
            seq_indexes_list=seq_indexes_list,
            cp_common_metadata=cp_common_metadata,
            scatter_indexes_tensor=scatter_indexes_tensor,
            gather_indexes_tensor=gather_indexes_tensor,
            enable_lightly_cp=True
        )
        return cm_base

    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
        num_scheduled_tokens: np.ndarray,
    ) -> tuple[
        torch.Tensor,
        SpecDecodeMetadata | None,
    ]:
        """Prepare per-token inputs and attention bookkeeping for one step.

        The steps are:
        - commit the block table;
        - compute positions and gather input IDs (and, with prompt embeds,
          the token-ID mask and embeddings) into the CPU staging buffers;
        - compute the slot mapping;
        - fill ``query_start_loc``, ``seq_lens`` and the discard mask;
        - upload input IDs and positions to the GPU;
        - derive the logits indices, with or without speculative decoding;
        - activate LoRAs if configured.

        Args:
            scheduler_output: Scheduler output for this step.
            num_scheduled_tokens: Per-request scheduled token counts, in
                ``self.input_batch`` order.

        :return: tuple[
            logits_indices, spec_decode_metadata,
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit_block_table(num_reqs)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions.np[:total_num_scheduled_tokens]
        np.add(
            self.input_batch.num_computed_tokens_cpu[req_indices],
            arange,
            out=positions_np,
        )

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Calculate XD-RoPE positions.
        # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
        if self.uses_xdrope_dim > 0:
            self._calc_xdrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (
            positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
        )
        token_indices_tensor = torch.from_numpy(token_indices)

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(
            self.input_batch.token_ids_cpu_tensor.flatten(),
            0,
            token_indices_tensor,
            out=self.input_ids.cpu[:total_num_scheduled_tokens],
        )
        if self.enable_prompt_embeds:
            is_token_ids = self.input_batch.is_token_ids_tensor.flatten()
            torch.index_select(
                is_token_ids,
                0,
                token_indices_tensor,
                out=self.is_token_ids.cpu[:total_num_scheduled_tokens],
            )

        # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
        # the InputBatch, we need to fill in the prompt embeds into the expected
        # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
        if self.input_batch.req_prompt_embeds:
            output_idx = 0
            for req_idx in range(num_reqs):
                num_sched = num_scheduled_tokens[req_idx]

                # Skip if this request doesn't have embeddings
                if req_idx not in self.input_batch.req_prompt_embeds:
                    output_idx += num_sched
                    continue

                # Skip if no tokens scheduled
                if num_sched <= 0:
                    output_idx += num_sched
                    continue

                req_embeds = self.input_batch.req_prompt_embeds[req_idx]
                start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]

                # Skip if trying to read beyond available embeddings
                if start_pos >= req_embeds.shape[0]:
                    output_idx += num_sched
                    continue

                # Copy available embeddings
                end_pos = start_pos + num_sched
                actual_end = min(end_pos, req_embeds.shape[0])
                actual_num_sched = actual_end - start_pos

                if actual_num_sched > 0:
                    self.inputs_embeds.cpu[
                        output_idx : output_idx + actual_num_sched
                    ].copy_(req_embeds[start_pos:actual_end])

                output_idx += num_sched

        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

        # Prepare the attention metadata.
        self.query_start_loc.np[0] = 0
        self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
        self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
        self.query_start_loc.copy_to_gpu()
        query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]

        self.seq_lens.np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
        )
        # Fill unused with 0 for full cuda graph mode.
        self.seq_lens.np[num_reqs:].fill(0)
        self.seq_lens.copy_to_gpu()

        num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
        num_tokens_np = np.array(num_tokens, dtype=np.int32)

        # Record which requests should not be sampled,
        # so that we could clear the sampled tokens before returning
        self.discard_request_mask.np[:num_reqs] = (
            self.seq_lens.np[:num_reqs] < num_tokens_np
        )
        self.discard_request_mask.copy_to_gpu(num_reqs)

        # Copy the tensors to the GPU.
        self._prepare_input_ids(
            scheduler_output,
            total_num_scheduled_tokens,
            cu_num_tokens,
        )

        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self._copy_mrope_positions_to_gpu(total_num_scheduled_tokens)
        elif self.uses_xdrope_dim > 0:
            # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
            self._copy_xdrope_positions_to_gpu(total_num_scheduled_tokens)
        else:
            # Common case (1D positions)
            self.positions.copy_to_gpu(total_num_scheduled_tokens)

        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
            num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            # For chunked prefills, use -1 as mask rather than 0, as guided
            # decoding may rollback speculative tokens.
            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
            for (
                req_id,
                draft_token_ids,
            ) in scheduler_output.scheduled_spec_decode_tokens.items():
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)
                if (
                    self.input_batch.num_computed_tokens_cpu[req_idx]
                    >= self.input_batch.num_prompt_tokens[req_idx]
                ):
                    num_decode_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_ids = None
            if envs.VLLM_REJECT_SAMPLE_OPT:
                spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens, spec_decode_ids
            )
            logits_indices = spec_decode_metadata.logits_indices
            num_sampled_tokens = num_draft_tokens + 1
            # For DECODE only cuda graph of some attention backends (e.g., GDN).
            self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens
            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
            self.num_decode_draft_tokens.copy_to_gpu()

        # Hot-Swap lora model
        if self.lora_config:
            assert (
                np.sum(num_sampled_tokens)
                <= self.vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.set_active_loras(
                self.input_batch, num_scheduled_tokens, num_sampled_tokens
            )

        return (
            logits_indices,
            spec_decode_metadata,
        )

    def _build_attention_metadata(
        self,
        num_tokens: int,
        num_reqs: int,
        max_query_len: int,
        num_tokens_padded: int | None = None,
        num_reqs_padded: int | None = None,
        ubatch_slices: UBatchSlices | None = None,
        logits_indices: torch.Tensor | None = None,
        use_spec_decode: bool = False,
        for_cudagraph_capture: bool = False,
        num_scheduled_tokens: dict[str, int] | None = None,
        cascade_attn_prefix_lens: list[list[int]] | None = None,
        slot_mappings: dict[int, torch.Tensor] | None = None,
    ) -> tuple[
        PerLayerAttnMetadata,
        CommonAttentionMetadata | None,
        torch.Tensor | None,
        torch.Tensor | None,
    ]:
        """Build the per-layer attention metadata for the current batch.

        One ``CommonAttentionMetadata`` is built (either directly, or through
        ``_prepare_cp_metadata`` when lightly-CP is active for this step) and then
        shallow-copied per KV-cache group, swapping in that group's block table,
        slot mapping and encoder seq lens before each attention group's builder
        runs on it.

        :return: tuple of
            ``(attn_metadata, spec_decode_common_attn_metadata,
            scatter_indexes_tensor, gather_indexes_tensor)``.
            ``attn_metadata`` maps layer name -> metadata (a list of such dicts when
            ``ubatch_slices`` is given). The scatter/gather tensors are only
            populated on the lightly-CP path (taken from the metadata returned by
            ``_prepare_cp_metadata``) and are ``None`` otherwise.
        """
        # Attention metadata is not needed for attention free models
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return {}, None, None, None

        num_tokens_padded = num_tokens_padded or num_tokens
        num_reqs_padded = num_reqs_padded or num_reqs
        assert num_reqs_padded is not None and num_tokens_padded is not None

        attn_metadata: PerLayerAttnMetadata = {}
        if ubatch_slices is not None:
            attn_metadata = [dict() for _ in range(len(ubatch_slices))]

        if for_cudagraph_capture:
            # For some attention backends (e.g. FA) with sliding window models we need
            # to make sure the backend see a max_seq_len that is larger to the sliding
            # window size when capturing to make sure the correct kernel is selected.
            max_seq_len = self.max_model_len
        else:
            max_seq_len = self.seq_lens.np[:num_reqs].max().item()

        if use_spec_decode:
            self.num_accepted_tokens.np[:num_reqs] = (
                self.input_batch.num_accepted_tokens_cpu[:num_reqs]
            )
            self.num_accepted_tokens.np[num_reqs:].fill(1)
            self.num_accepted_tokens.copy_to_gpu()

        kv_cache_groups = self.kv_cache_config.kv_cache_groups

        def _get_block_table(kv_cache_gid: int) -> torch.Tensor:
            # Re-asserted so type checkers narrow the closed-over optionals.
            assert num_reqs_padded is not None and num_tokens_padded is not None
            kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
                blk_table_tensor = torch.zeros(
                    (num_reqs_padded, 1),
                    dtype=torch.int32,
                    device=self.device,
                )
            else:
                blk_table = self.input_batch.block_table[kv_cache_gid]
                blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)

            # Fill unused with -1. Needed for reshape_and_cache in full cuda
            # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
            blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
            return blk_table_tensor

        assert slot_mappings is not None
        block_table_gid_0 = _get_block_table(0)
        slot_mapping_gid_0 = slot_mappings[0]

        scatter_indexes_tensor = None
        gather_indexes_tensor = None

        if self.model_config.enable_return_routed_experts:
            self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()

        # Lightly-CP is only used for this step when it is configured AND the batch
        # is larger than the threshold (attribute name spelled as defined: sic).
        enable_lightly_cp = (
            self.enable_lightly_cp and num_tokens > self.lightly_cp_threshould
        )
        if not enable_lightly_cp:
            cm_base = CommonAttentionMetadata(
                query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
                query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
                seq_lens=self.seq_lens.gpu[:num_reqs_padded],
                _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
                _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
                    :num_reqs_padded
                ],
                num_reqs=num_reqs_padded,
                num_actual_tokens=num_tokens_padded,
                num_kv_actual_tokens=num_tokens_padded,
                max_query_len=max_query_len,
                max_seq_len=max_seq_len,
                block_table_tensor=block_table_gid_0,
                slot_mapping=slot_mapping_gid_0,
                causal=True,
            )
        else:
            cm_base = self._prepare_cp_metadata(
                num_reqs_padded,
                max_query_len,
                max_seq_len,
                num_tokens,
                block_table_gid_0,
                slot_mapping_gid_0,
                self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs_padded],
            )
            scatter_indexes_tensor = cm_base.scatter_indexes_tensor
            gather_indexes_tensor = cm_base.gather_indexes_tensor

        if self.dcp_world_size > 1:
            self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
                self.seq_lens.cpu[:num_reqs],
                self.dcp_world_size,
                self.dcp_rank,
                self.parallel_config.cp_kv_cache_interleave_size,
            )
            self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
            self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)

            cm_base.dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
            cm_base.dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[
                :num_reqs_padded
            ]

        if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
            cm_base.num_logits_indices = logits_indices.size(0)
            cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
                logits_indices
            )

        # Cache attention metadata builds across hybrid KV-cache groups
        # The only thing that changes between different hybrid KV-cache groups when the
        # same metadata builder and KVCacheSpec is the same is the block table, so we
        # can cache the attention metadata builds and just update the block table using
        # `builder.update_block_table` if the builder supports it.
        cached_attn_metadata: dict[
            tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
        ] = {}

        def _build_attn_group_metadata(
            kv_cache_gid: int,
            attn_gid: int,
            common_attn_metadata: CommonAttentionMetadata,
            ubid: int | None = None,
        ) -> None:
            # Build (or reuse from cache) one attention group's metadata and assign
            # it to every layer in that group.
            attn_group = self.attn_groups[kv_cache_gid][attn_gid]
            builder = attn_group.get_metadata_builder(ubid or 0)
            kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
            if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
                kv_cache_spec = kv_cache_spec.kv_cache_specs[attn_group.layer_names[0]]
            cache_key = (kv_cache_spec, type(builder))

            cascade_attn_prefix_len = (
                cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
                if cascade_attn_prefix_lens
                else 0
            )

            extra_attn_metadata_args = {}
            if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
                assert ubid is None, "UBatching not supported with GDN yet"
                extra_attn_metadata_args = dict(
                    num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
                    num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
                        :num_reqs_padded
                    ],
                )

            if for_cudagraph_capture:
                attn_metadata_i = builder.build_for_cudagraph_capture(
                    common_attn_metadata
                )
            elif (
                cache_key in cached_attn_metadata
                and builder.supports_update_block_table
            ):
                attn_metadata_i = builder.update_block_table(
                    cached_attn_metadata[cache_key],
                    common_attn_metadata.block_table_tensor,
                    common_attn_metadata.slot_mapping,
                )
            else:
                attn_metadata_i = builder.build(
                    common_prefix_len=cascade_attn_prefix_len,
                    common_attn_metadata=common_attn_metadata,
                    **extra_attn_metadata_args,
                )
                if builder.supports_update_block_table:
                    cached_attn_metadata[cache_key] = attn_metadata_i

            if ubid is None:
                assert isinstance(attn_metadata, dict)
                attn_metadata_dict = attn_metadata
            else:
                assert isinstance(attn_metadata, list)
                attn_metadata_dict = attn_metadata[ubid]

            for layer_name in attn_group.layer_names:
                attn_metadata_dict[layer_name] = attn_metadata_i

        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        spec_decode_common_attn_metadata = None
        for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups):
            cm = copy(cm_base)  # shallow copy

            # Basically only the encoder seq_lens, block_table and slot_mapping change
            # for each kv_cache_group.
            cm.encoder_seq_lens, cm.encoder_seq_lens_cpu = self._get_encoder_seq_lens(
                num_scheduled_tokens or {},
                kv_cache_group.kv_cache_spec,
                num_reqs_padded,
                for_cudagraph_capture=for_cudagraph_capture,
            )
            if kv_cache_gid > 0:
                cm.block_table_tensor = _get_block_table(kv_cache_gid)
                cm.slot_mapping = slot_mappings[kv_cache_gid]
                # Under lightly-CP the group-0 block table was already reordered by
                # `_prepare_cp_metadata`; apply the same row selection here.
                if enable_lightly_cp and cm.seq_indexes_list is not None:
                    cm.block_table_tensor = cm.block_table_tensor[cm.seq_indexes_list]

            if (
                self.speculative_config
                and spec_decode_common_attn_metadata is None
                and hasattr(self, "drafter")
            ):
                # An Eagle drafter takes the metadata of the KV-cache group that
                # contains its first attention layer; any other drafter takes the
                # first group visited.
                if (
                    not isinstance(self.drafter, EagleProposer)
                    or self.drafter.attn_layer_names[0] in kv_cache_group.layer_names
                ):
                    spec_decode_common_attn_metadata = (
                        cm.cp_common_metadata if enable_lightly_cp else cm
                    )

            for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
                if ubatch_slices is not None:
                    for ubid, _cm in enumerate(split_attn_metadata(ubatch_slices, cm)):
                        _build_attn_group_metadata(kv_cache_gid, attn_gid, _cm, ubid)
                else:
                    _build_attn_group_metadata(kv_cache_gid, attn_gid, cm)

        if self.is_mm_prefix_lm:
            # Map request index -> embed ranges of all of its multimodal features.
            req_doc_ranges = {}
            for req_id in self.input_batch.req_ids:
                image_doc_ranges = []
                req_state = self.requests[req_id]
                for mm_feature in req_state.mm_features:
                    pos_info = mm_feature.mm_position
                    img_doc_range = pos_info.extract_embeds_range()
                    image_doc_ranges.extend(img_doc_range)
                req_idx = self.input_batch.req_id_to_index[req_id]
                req_doc_ranges[req_idx] = image_doc_ranges

            if isinstance(attn_metadata, list):
                for ub_metadata in attn_metadata:
                    for _metadata in ub_metadata.values():
                        _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]
            else:
                for _metadata in attn_metadata.values():
                    _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]

        # Use the per-step flag: when lightly-CP is configured but not active for
        # this batch, the spec-decode metadata is a plain padded
        # CommonAttentionMetadata and still needs un-padding.
        if (
            not enable_lightly_cp
            and spec_decode_common_attn_metadata is not None
            and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded)
        ):
            # Currently the drafter still only uses piecewise cudagraphs (and modifies
            # the attention metadata in directly), and therefore does not want to use
            # padded attention metadata.
            spec_decode_common_attn_metadata = (
                spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
            )

        return (
            attn_metadata,
            spec_decode_common_attn_metadata,
            scatter_indexes_tensor,
            gather_indexes_tensor,
        )

    def _compute_cascade_attn_prefix_lens(
        self,
        num_scheduled_tokens: np.ndarray,
2293
        num_computed_tokens: np.ndarray,
2294
2295
2296
2297
2298
2299
2300
        num_common_prefix_blocks: list[int],
    ) -> list[list[int]] | None:
        """
        :return: Optional[cascade_attn_prefix_lens]
            cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
            None if we should not use cascade attention
        """
2301

2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
        use_cascade_attn = False
        num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
        cascade_attn_prefix_lens: list[list[int]] = [
            [] for _ in range(num_kv_cache_groups)
        ]

        for kv_cache_gid in range(num_kv_cache_groups):
            for attn_group in self.attn_groups[kv_cache_gid]:
                if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                    cascade_attn_prefix_len = 0
                else:
                    # 0 if cascade attention should not be used
                    cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
                        num_scheduled_tokens,
2316
                        num_computed_tokens,
2317
2318
2319
2320
2321
2322
                        num_common_prefix_blocks[kv_cache_gid],
                        attn_group.kv_cache_spec,
                        attn_group.get_metadata_builder(),
                    )
                cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
                use_cascade_attn |= cascade_attn_prefix_len > 0
2323

2324
        return cascade_attn_prefix_lens if use_cascade_attn else None
2325

2326
2327
2328
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
2329
        num_computed_tokens: np.ndarray,
2330
        num_common_prefix_blocks: int,
2331
2332
        kv_cache_spec: KVCacheSpec,
        attn_metadata_builder: AttentionMetadataBuilder,
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens.
        """
2351

2352
        common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
2390
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
2391
2392
2393
2394
2395
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
2396
        common_prefix_len = min(common_prefix_len, num_computed_tokens.min())
2397
        # common_prefix_len should be a multiple of the block size.
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
        common_prefix_len = (
            common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size
        )
        use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or (
            isinstance(kv_cache_spec, FullAttentionSpec)
            and kv_cache_spec.sliding_window is not None
        )
        use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or (
            isinstance(kv_cache_spec, FullAttentionSpec)
            and kv_cache_spec.attention_chunk_size is not None
        )
2409
2410
        assert isinstance(kv_cache_spec, AttentionSpec)
        use_cascade = attn_metadata_builder.use_cascade_attention(
2411
2412
2413
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
2414
            num_kv_heads=kv_cache_spec.num_kv_heads,
2415
            use_alibi=self.use_alibi,
2416
            use_sliding_window=use_sliding_window,
2417
            use_local_attention=use_local_attention,
2418
            num_sms=self.num_sms,
2419
            dcp_world_size=self.dcp_world_size,
2420
2421
2422
        )
        return common_prefix_len if use_cascade else 0

guanyu1's avatar
guanyu1 committed
2423
2424
    def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
        """Fill ``self.xdrope_positions`` for every token scheduled this step.

        Requests are visited in ``self.input_batch.req_ids`` order and their
        tokens are written back-to-back along the last axis, starting at 0.
        Tokens still inside the prompt copy the pre-computed
        ``req.xdrope_positions``; tokens past the prompt are written by
        ``XDRotaryEmbedding.get_next_input_positions_tensor``.
        """
        xdrope_pos_ptr = 0
        for index, req_id in enumerate(self.input_batch.req_ids):
            req = self.requests[req_id]
            assert req.xdrope_positions is not None

            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
                req.prompt_token_ids, req.prompt_embeds
            )

            # Split this step's tokens at the prompt boundary: the first
            # `prompt_part_len` are still prompt tokens, the rest are completion.
            prompt_part_len = min(
                num_scheduled_tokens, max(0, num_prompt_tokens - num_computed_tokens)
            )
            completion_part_len = num_scheduled_tokens - prompt_part_len

            if prompt_part_len > 0:
                # prompt's xdrope_positions are pre-computed
                dst_start = xdrope_pos_ptr
                dst_end = xdrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len
                self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
                    :, src_start:src_end
                ]
                xdrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's xdrope_positions on-the-fly
                XDRotaryEmbedding.get_next_input_positions_tensor(
                    out=self.xdrope_positions.np,
                    out_offset=xdrope_pos_ptr,
                    context_len=num_computed_tokens + prompt_part_len,
                    num_new_tokens=completion_part_len,
                )
                xdrope_pos_ptr += completion_part_len

    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        """Fill ``self.mrope_positions`` for every token scheduled this step.

        Requests are visited in ``self.input_batch.req_ids`` order and their
        tokens are written back-to-back starting at position 0. Tokens still
        inside the prompt copy the pre-computed ``req.mrope_positions``; tokens
        past the prompt are generated from ``req.mrope_position_delta``.

        With ``self.use_1d_mrope`` the buffer is viewed token-major as
        ``(max_num_tokens + 1, 3)`` and each generated position is written
        identically to all three components.
        """
        mrope_pos_ptr = 0
        # Token-major views; only bound (and only used) when use_1d_mrope is set.
        mrope_positions_token_major = None
        mrope_positions_token_major_np = None
        if self.use_1d_mrope:
            mrope_positions_token_major = self.mrope_positions.cpu.view(
                self.max_num_tokens + 1, 3
            )
            mrope_positions_token_major_np = self.mrope_positions.np.reshape(
                self.max_num_tokens + 1, 3
            )

        for index, req_id in enumerate(self.input_batch.req_ids):
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
                req.prompt_token_ids, req.prompt_embeds
            )

            # Split this step's tokens at the prompt boundary: the first
            # `prompt_part_len` are still prompt tokens, the rest are completion.
            prompt_part_len = min(
                num_scheduled_tokens, max(0, num_prompt_tokens - num_computed_tokens)
            )
            completion_part_len = num_scheduled_tokens - prompt_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                if self.use_1d_mrope:
                    mrope_positions_token_major[dst_start:dst_end, :].copy_(
                        req.mrope_positions[:, src_start:src_end].transpose(0, 1)
                    )
                else:
                    self.mrope_positions.cpu[:, dst_start:dst_end] = (
                        req.mrope_positions[:, src_start:src_end]
                    )
                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                assert req.mrope_position_delta is not None
                if self.use_1d_mrope:
                    first_pos = (
                        req.mrope_position_delta
                        + num_computed_tokens
                        + prompt_part_len
                    )
                    values = np.arange(
                        first_pos,
                        first_pos + completion_part_len,
                        dtype=mrope_positions_token_major_np.dtype,
                    )
                    # Broadcast each position across the 3 components.
                    mrope_positions_token_major_np[dst_start:dst_end, :] = values[
                        :, None
                    ]
                else:
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        out=self.mrope_positions.np,
                        out_offset=dst_start,
                        mrope_position_delta=req.mrope_position_delta,
                        context_len=num_computed_tokens + prompt_part_len,
                        num_new_tokens=completion_part_len,
                    )
                mrope_pos_ptr += completion_part_len

    def _calc_spec_decode_metadata(
        self,
2548
2549
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
王敏's avatar
王敏 committed
2550
        spec_decode_ids: Optional[list[str]] = None
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
    ) -> SpecDecodeMetadata:
        # Inputs:
        # cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
        # num_draft_tokens:         [  3,   0,   2,   0,   1]
        # Outputs:
        # cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
        # logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
        #                            206, 207, 208]
        # target_logits_indices:    [  0,   1,   2,   5,   6,   9]
        # bonus_logits_indices:     [  3,   4,   7,   8,  10]

        # Compute the logits indices.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1
2565
2566
2567
2568

        # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
        # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
2569
2570
            num_sampled_tokens, cumsum_dtype=np.int32
        )
2571
        # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
2572
        logits_indices = np.repeat(
2573
2574
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
        )
2575
        # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
2576
2577
2578
2579
2580
2581
        logits_indices += arange

        # Compute the bonus logits indices.
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # Compute the draft logits indices.
2582
2583
2584
        # cu_num_draft_tokens: [3, 3, 5, 5, 6]
        # arange: [0, 1, 2, 0, 1, 0]
        cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
2585
2586
            num_draft_tokens, cumsum_dtype=np.int32
        )
2587
2588
        # [0, 0, 0, 5, 5, 9]
        target_logits_indices = np.repeat(
2589
2590
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens
        )
2591
2592
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices += arange
2593
        draft_token_indices = target_logits_indices + 1
2594

2595
        # TODO: Optimize the CPU -> GPU copy.
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
        # cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
        #     self.device, non_blocking=True
        # )
        # cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
        #     self.device, non_blocking=True
        # )
        # logits_indices = torch.from_numpy(logits_indices).to(
        #     self.device, non_blocking=True
        # )
        # target_logits_indices = torch.from_numpy(target_logits_indices).to(
        #     self.device, non_blocking=True
        # )
        # bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
        #     self.device, non_blocking=True
        # )

        # # Compute the draft token ids.
        # # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
        # draft_token_ids = self.input_ids.gpu[logits_indices]
        # draft_token_ids = draft_token_ids[target_logits_indices + 1]

        # Optimize the H2D in the process of creating spec decode metadata
        fused_meta_data = cu_num_draft_tokens.tolist() + cu_num_sampled_tokens.tolist()\
              + logits_indices.tolist() + target_logits_indices.tolist() + bonus_logits_indices.tolist()\
              + draft_token_indices.tolist()
        
        fused_meta_data_len = np.array([len(cu_num_draft_tokens), len(cu_num_sampled_tokens),\
                                        len(logits_indices), len(target_logits_indices),\
                                            len(bonus_logits_indices), len(draft_token_indices)], dtype=np.int32)
        cu_fused_meta_data_len = np.cumsum(fused_meta_data_len, dtype=np.int32)
        fused_meta_data = torch.tensor(
            fused_meta_data, dtype=torch.int32, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)
        
        cu_num_draft_tokens = fused_meta_data[:cu_fused_meta_data_len[0]]
        cu_num_sampled_tokens = fused_meta_data[cu_fused_meta_data_len[0]:cu_fused_meta_data_len[1]]
        logits_indices = fused_meta_data[cu_fused_meta_data_len[1]:cu_fused_meta_data_len[2]]
        target_logits_indices = fused_meta_data[cu_fused_meta_data_len[2]:cu_fused_meta_data_len[3]]
        bonus_logits_indices = fused_meta_data[cu_fused_meta_data_len[3]:cu_fused_meta_data_len[4]]
        draft_token_indices = fused_meta_data[cu_fused_meta_data_len[4]:cu_fused_meta_data_len[5]]

2637

2638
2639
        # Compute the draft token ids.
        # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
2640
        draft_token_ids = self.input_ids.gpu[logits_indices]
2641
        draft_token_ids = draft_token_ids[draft_token_indices]
2642

2643
        return SpecDecodeMetadata(
2644
2645
2646
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
2647
            cu_num_sampled_tokens=cu_num_sampled_tokens,
2648
2649
2650
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
王敏's avatar
王敏 committed
2651
            spec_decode_ids=spec_decode_ids,
2652
2653
        )

    def _prepare_kv_sharing_fast_prefill(
        self,
        logits_indices: torch.Tensor,
    ) -> torch.Tensor:
        assert self.kv_sharing_fast_prefill_logits_indices is not None
        num_logits = logits_indices.shape[0]
        assert num_logits > 0
2661
        self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices)
2662
2663
2664
2665
2666
        # There might have leftover indices in logits_indices[num_logits:]
        # from previous iterations, whose values may be greater than the
        # batch size in the current iteration. To ensure indices are always
        # valid, we fill the padded indices with the last index.
        self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
2667
2668
            logits_indices[-1].item()
        )
2669
2670
2671
2672
2673
        # Dispatch for the decoder portion of the model.
        _, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_logits, disable_full=True
        )
        num_logits_padded = batch_desc.num_tokens
2674
2675
2676
        logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
            :num_logits_padded
        ]
2677
2678
        return logits_indices_padded

2679
    def _batch_mm_inputs_from_scheduler(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[
        list[str],
        list[MultiModalKwargsItem],
        list[tuple[str, PlaceholderRange]],
    ]:
        """Batch multimodal inputs from scheduled encoder inputs.

        Args:
            scheduler_output: The scheduler output containing scheduled encoder
                inputs.

        Returns:
            A tuple of (mm_hashes, mm_kwargs, mm_lora_refs), aligned by index:
            - mm_hashes: ``mm_feature.identifier`` of each item
            - mm_kwargs: ``mm_feature.data`` of each item
            - mm_lora_refs: (req_id, ``mm_feature.mm_position``) of each item

            Features whose ``data`` is None are skipped. Three empty lists are
            returned when no encoder inputs are scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return [], [], []

        mm_hashes = list[str]()
        mm_kwargs = list[MultiModalKwargsItem]()
        # Multimodal LoRA reference info to map each multimodal item
        # back to its request & position.
        mm_lora_refs = list[tuple[str, PlaceholderRange]]()

        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]

            for mm_input_id in encoder_input_ids:
                mm_feature = req_state.mm_features[mm_input_id]
                if mm_feature.data is None:
                    continue

                mm_hashes.append(mm_feature.identifier)
                mm_kwargs.append(mm_feature.data)
                mm_lora_refs.append((req_id, mm_feature.mm_position))

        return mm_hashes, mm_kwargs, mm_lora_refs

    def _execute_mm_encoder(
        self, scheduler_output: "SchedulerOutput"
    ) -> list[torch.Tensor]:
2725
        mm_hashes, mm_kwargs, mm_lora_refs = self._batch_mm_inputs_from_scheduler(
2726
2727
            scheduler_output
        )
2728
2729

        if not mm_kwargs:
2730
            return []
2731

2732
2733
2734
2735
2736
2737
        should_time = bool(
            self.observability_config
            and self.observability_config.enable_mm_processor_stats
            and scheduler_output.scheduled_encoder_inputs
        )

2738
2739
2740
2741
2742
2743
2744
        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
2745
        model = cast(SupportsMultiModal, self.model)
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802

        if self.lora_config and self.lora_manager.supports_tower_connector_lora():
            # Build LoRA mappings independently for encoder inputs
            # (encoder batch structure is different from main batch)
            prompt_lora_mapping = []
            token_lora_mapping = []
            lora_requests = set()
            encoder_token_counts = []

            for req_id, pos_info in mm_lora_refs:
                req_idx = self.input_batch.req_id_to_index[req_id]
                lora_id = int(self.input_batch.request_lora_mapping[req_idx])

                # Prefer pos_info.get_num_embeds to count precise MM embedding tokens.
                num_tokens = self.model.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
                    pos_info.get_num_embeds
                )
                prompt_lora_mapping.append(lora_id)
                token_lora_mapping.extend([lora_id] * num_tokens)
                encoder_token_counts.append(num_tokens)

                if lora_id > 0:
                    lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
                    if lora_request is not None:
                        lora_requests.add(lora_request)

            # Set tower adapter mapping
            tower_mapping = LoRAMapping(
                tuple(token_lora_mapping),
                tuple(prompt_lora_mapping),
                is_prefill=True,
                type=LoRAMappingType.TOWER,
            )
            self.lora_manager.set_active_adapters(lora_requests, tower_mapping)

            if hasattr(self.model, "get_num_mm_connector_tokens"):
                post_op_counts = [
                    self.model.get_num_mm_connector_tokens(num_tokens)  # type: ignore[attr-defined]
                    for num_tokens in encoder_token_counts
                ]

                connector_token_mapping = np.repeat(
                    np.array(prompt_lora_mapping, dtype=np.int32),
                    np.array(post_op_counts, dtype=np.int32),
                )
                connector_mapping = LoRAMapping(
                    index_mapping=tuple(connector_token_mapping.tolist()),
                    prompt_mapping=tuple(prompt_lora_mapping),
                    is_prefill=True,
                    type=LoRAMappingType.CONNECTOR,
                )

                self.lora_manager.set_active_adapters(
                    lora_requests,
                    connector_mapping,
                )

2803
        encoder_outputs: list[torch.Tensor] = []
2804
2805
        # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs
        current_item_idx = 0
2806
        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
2807
2808
2809
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
2810
        ):
2811
            curr_group_outputs: MultiModalEmbeddings
2812
2813

            # EVS-related change.
2814
            # (ekhvedchenia): Temporary hack to limit peak memory usage when
2815
            # processing multimodal data. This solves the issue with scheduler
2816
2817
2818
2819
            # putting too many video samples into a single batch. Scheduler
            # uses pruned vision tokens count to compare it versus compute
            # budget which is incorrect (Either input media size or non-pruned
            # output vision tokens count should be considered)
2820
2821
2822
2823
2824
2825
2826
            # TODO(ywang96): Fix memory profiling to take EVS into account and
            # remove this hack.
            if (
                self.is_multimodal_pruning_enabled
                and modality == "video"
                and num_items > 1
            ):
2827
                curr_group_outputs_lst = list[torch.Tensor]()
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
                for video_idx in range(num_items):
                    video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx]
                    with self.timed_encoder_operation(
                        should_time, mm_lora_refs, current_item_idx + video_idx, 1
                    ):
                        _, _, micro_batch_mm_inputs = next(
                            group_mm_kwargs_by_modality(
                                [video_mm_kwargs_item],
                                device=self.device,
                                pin_memory=self.pin_memory,
                            )
2839
                        )
2840

2841
2842
2843
                        micro_batch_outputs = model.embed_multimodal(
                            **micro_batch_mm_inputs
                        )
2844

2845
                        curr_group_outputs_lst.extend(micro_batch_outputs)
2846
2847

                curr_group_outputs = curr_group_outputs_lst
2848
2849
2850
2851
2852
2853
2854
2855
            else:
                # Run the encoder.
                # `curr_group_outputs` is either of the following:
                # 1. A tensor of shape (num_items, feature_size, hidden_size)
                # in case feature_size is fixed across all multimodal items.
                # 2. A list or tuple (length: num_items) of tensors,
                # each of shape (feature_size, hidden_size) in case the feature
                # size is dynamic depending on the input multimodal items.
2856
2857
2858
2859
2860

                with self.timed_encoder_operation(
                    should_time, mm_lora_refs, current_item_idx, num_items
                ):
                    curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
2861

2862
2863
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
2864
                expected_num_items=num_items,
2865
            )
2866
            encoder_outputs.extend(curr_group_outputs)
2867

2868
2869
            current_item_idx += num_items

2870
        # Cache the encoder outputs by mm_hash
2871
        for mm_hash, output in zip(mm_hashes, encoder_outputs):
2872
            self.encoder_cache[mm_hash] = output
2873
2874
            logger.debug("Finish execute for mm hash %s", mm_hash)
            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
2875

2876
        return encoder_outputs
2877
2878

    def _gather_mm_embeddings(
2879
2880
        self,
        scheduler_output: "SchedulerOutput",
2881
        shift_computed_tokens: int = 0,
2882
2883
2884
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

2885
2886
2887
2888
2889
        # Swap to the other buffer to avoid race condition with previous
        # iteration's async copy that may still be reading from CPU.
        self.is_mm_embed_idx = 1 - self.is_mm_embed_idx
        is_mm_embed_buf = self.is_mm_embed_buffers[self.is_mm_embed_idx]

2890
        mm_embeds = list[torch.Tensor]()
2891
        is_mm_embed = is_mm_embed_buf.cpu
2892
2893
2894
        is_mm_embed[:total_num_scheduled_tokens] = False

        req_start_idx = 0
2895
        should_sync_mrope_positions = False
2896
        should_sync_xdrope_positions = False
2897

2898
        for req_id in self.input_batch.req_ids:
2899
2900
            mm_embeds_req: list[torch.Tensor] = []

2901
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
2902
            req_state = self.requests[req_id]
2903
            num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
2904

2905
2906
            for mm_feature in req_state.mm_features:
                pos_info = mm_feature.mm_position
2907
2908
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
2925
2926
                    num_encoder_tokens,
                )
2927
                assert start_idx < end_idx
2928
2929
2930
2931
2932
2933
2934
                curr_embeds_start, curr_embeds_end = (
                    pos_info.get_embeds_indices_in_range(start_idx, end_idx)
                )
                # If there are no embeddings in the current range, we skip
                # gathering the embeddings.
                if curr_embeds_start == curr_embeds_end:
                    continue
2935

2936
                mm_hash = mm_feature.identifier
2937
                encoder_output = self.encoder_cache.get(mm_hash, None)
2938
                assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
2939
2940
2941

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]
2942
2943
2944
                    mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
                else:
                    mm_embeds_item = encoder_output[start_idx:end_idx]
2945

2946
                req_start_pos = req_start_idx + start_pos - num_computed_tokens
2947
2948
2949
2950
2951
2952
2953
2954
2955
                # OR mask for overlapping mm_features (use_audio_in_video)
                if is_embed is None:
                    is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                        True
                    )
                else:
                    is_mm_embed[
                        req_start_pos + start_idx : req_start_pos + end_idx
                    ] |= is_embed
2956
2957
2958
                mm_embeds_req.append(mm_embeds_item)

            if self.is_multimodal_pruning_enabled and self.uses_mrope:
2959
                assert req_state.mrope_positions is not None
2960
2961
2962
2963
2964
2965
2966
                should_sync_mrope_positions = True
                mm_embeds_req, new_mrope_positions, new_delta = (
                    self.model.recompute_mrope_positions(
                        input_ids=req_state.prompt_token_ids,
                        multimodal_embeddings=mm_embeds_req,
                        mrope_positions=req_state.mrope_positions,
                        num_computed_tokens=req_state.num_computed_tokens,
2967
2968
                    )
                )
2969
2970
2971
2972
                req_state.mrope_positions.copy_(new_mrope_positions)
                req_state.mrope_position_delta = new_delta

            mm_embeds.extend(mm_embeds_req)
2973
2974
            req_start_idx += num_scheduled_tokens

2975
        is_mm_embed = is_mm_embed_buf.copy_to_gpu(total_num_scheduled_tokens)
2976
2977
2978

        if should_sync_mrope_positions:
            self._calc_mrope_positions(scheduler_output)
guanyu1's avatar
guanyu1 committed
2979
            self._copy_mrope_positions_to_gpu(total_num_scheduled_tokens)
2980

2981
2982
        if should_sync_xdrope_positions:
            self._calc_xdrope_positions(scheduler_output)
guanyu1's avatar
guanyu1 committed
2983
            self._copy_xdrope_positions_to_gpu(total_num_scheduled_tokens)
2984

2985
        return mm_embeds, is_mm_embed
2986

2987
    def get_model(self) -> nn.Module:
        """Return the raw model, unwrapping a CUDAGraphWrapper/UBatchWrapper."""
        model = self.model
        if isinstance(model, (CUDAGraphWrapper, UBatchWrapper)):
            # Get raw model out of the cudagraph wrapper.
            return model.unwrap()
        return model

    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        """List the generation tasks the underlying model supports.

        "generate" is included for text-generation models and "transcription"
        for transcription-capable models. A model that is transcription-only
        reports exactly ``["transcription"]``.
        """
        model = self.get_model()
        tasks = list[GenerationTask]()

        if is_text_generation_model(model):
            tasks.append("generate")

        if not supports_transcription(model):
            return tasks
        if model.supports_transcription_only:
            return ["transcription"]

        tasks.append("transcription")
        return tasks

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """List the pooling tasks supported by the model's pooler.

        Returns an empty list for non-pooling models. "score" is removed
        unless the HF config's ``num_labels`` (default 0 when absent) is 1.
        """
        model = self.get_model()
        if not is_pooling_model(model):
            return []

        supported_tasks = list(model.pooler.get_supported_tasks())

        if "score" in supported_tasks:
            num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
            if num_labels != 1:
                supported_tasks.remove("score")
                logger.debug_once("Score API is only enabled for num_labels == 1.")

        return supported_tasks

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported for the configured runner type.

        Generation tasks are reported for the "generate" runner and pooling
        tasks for the "pooling" runner; any other runner type yields ``()``.
        """
        runner_type = self.model_config.runner_type
        tasks = list[SupportedTask]()
        if runner_type == "generate":
            tasks += self.get_supported_generation_tasks()
        elif runner_type == "pooling":
            tasks += self.get_supported_pooling_tasks()
        return tuple(tasks)
    def sync_and_slice_intermediate_tensors(
        self,
        num_tokens: int,
        intermediate_tensors: IntermediateTensors | None,
        sync_self: bool,
    ) -> IntermediateTensors:
        """Optionally copy incoming tensors into the persistent buffers, then slice.

        Args:
            num_tokens: Number of tokens to keep from each buffer.
            intermediate_tensors: Incoming tensors; must not be None when
                ``sync_self`` is True.
            sync_self: If True, first copy ``intermediate_tensors`` into
                ``self.intermediate_tensors`` with ``non_blocking=True``.

        Returns:
            IntermediateTensors whose entries are prefixes of
            ``self.intermediate_tensors``. When the residual is scattered for
            sequence parallelism, the "residual" entry has
            ``num_tokens // tensor_parallel_size`` rows.
        """
        assert self.intermediate_tensors is not None

        tp = self.vllm_config.parallel_config.tensor_parallel_size
        is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens)

        def slice_len(key: str) -> int:
            # When sequence parallelism is enabled, the "residual" tensor is
            # sharded across tensor parallel ranks, so each rank only needs
            # its own slice.
            return num_tokens // tp if key == "residual" and is_rs else num_tokens

        if sync_self:
            assert intermediate_tensors is not None
            for k, v in intermediate_tensors.items():
                copy_len = slice_len(k)
                self.intermediate_tensors[k][:copy_len].copy_(
                    v[:copy_len], non_blocking=True
                )

        return IntermediateTensors(
            {k: v[: slice_len(k)] for k, v in self.intermediate_tensors.items()}
        )

    def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
        """
        Step for the EPLB (Expert Parallelism Load Balancing) state.

        Does nothing unless ``parallel_config.enable_eplb`` is set. Otherwise
        it requires an initialized ``self.eplb_state`` and a mixture-of-experts
        model, and forwards to ``eplb_state.step``.

        Args:
            is_dummy: Forwarded to ``eplb_state.step``.
            is_profile: Forwarded to ``eplb_state.step``.
        """
        if not self.parallel_config.enable_eplb:
            return

        assert self.eplb_state is not None
        model = self.get_model()
        assert is_mixture_of_experts(model)

        self.eplb_state.step(
            is_dummy,
            is_profile,
            log_stats=self.parallel_config.eplb_config.log_balancedness,
        )

    def _pool(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        num_scheduled_tokens_np: np.ndarray,
        kv_connector_output: KVConnectorOutput | None,
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Run the pooler on this step's hidden states and package the result.

        Args:
            hidden_states: Model hidden states; only the first
                ``num_scheduled_tokens`` rows are used.
            num_scheduled_tokens: Total number of tokens scheduled this step.
            num_scheduled_tokens_np: Scheduled token counts (presumably
                per request) passed to ``build_pooling_cursor``.
            kv_connector_output: Forwarded unchanged into the runner output.

        Returns:
            A ``ModelRunnerOutput`` whose ``pooler_output`` has one entry per
            request. An entry is None for a request whose sequence length
            has not yet reached its prompt length. With async scheduling,
            and if at least one request finished, the output is instead
            wrapped in an ``AsyncGPUPoolingModelRunnerOutput``.
        """
        num_reqs = self.input_batch.num_reqs
        assert num_reqs == len(self.input_batch.pooling_params), (
            "Either all or none of the requests in a batch must be pooling request"
        )

        hidden_states = hidden_states[:num_scheduled_tokens]
        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]

        pooling_metadata = self.input_batch.get_pooling_metadata()
        pooling_metadata.build_pooling_cursor(
            num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device
        )

        model = cast(VllmModelForPooling, self.model)
        raw_pooler_output: PoolerOutput = model.pooler(
            hidden_states=hidden_states, pooling_metadata=pooling_metadata
        )

        # A request's pooled output is only reported once its sequence length
        # equals its prompt length.
        finished_mask = [
            seq_len == prompt_len
            for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
        ]

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids.copy(),
            req_id_to_index=self.input_batch.req_id_to_index.copy(),
            kv_connector_output=kv_connector_output,
        )

        if raw_pooler_output is None or not any(finished_mask):
            model_runner_output.pooler_output = [None] * num_reqs
            return model_runner_output

        if self.use_async_scheduling:
            return AsyncGPUPoolingModelRunnerOutput(
                model_runner_output=model_runner_output,
                raw_pooler_output=raw_pooler_output,
                finished_mask=finished_mask,
                async_output_copy_stream=self.async_output_copy_stream,
            )

        # Synchronous path: start the device-to-host copies, then wait for
        # them via _sync_device before returning.
        raw_pooler_output = json_map_leaves(
            lambda x: None if x is None else x.to("cpu", non_blocking=True),
            raw_pooler_output,
        )
        model_runner_output.pooler_output = [
            out if include else None
            for out, include in zip(raw_pooler_output, finished_mask)
        ]
        self._sync_device()

        return model_runner_output

    def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
        """Round ``num_scheduled_tokens`` up to a multiple of the TP size.

        Called from ``_pad_for_sequence_parallelism`` when lightly context
        parallelism is enabled and the token count exceeds its threshold.
        """
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        return round_up(num_scheduled_tokens, tp_size)

    def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
3149
3150
        # Pad tokens to multiple of tensor_parallel_size when
        # enabled collective fusion for SP
3151
3152
3153

        if self.enable_lightly_cp and num_scheduled_tokens > self.lightly_cp_threshould:
            return self._pad_for_mla_cp(num_scheduled_tokens)
3154
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
3155
        if self.compilation_config.pass_config.enable_sp and tp_size > 1:
3156
3157
3158
            return round_up(num_scheduled_tokens, tp_size)
        return num_scheduled_tokens

Patrick von Platen's avatar
Patrick von Platen committed
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
    def _prepare_mm_inputs(
        self, num_tokens: int
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        if self.model.requires_raw_input_tokens:
            input_ids = self.input_ids.gpu[:num_tokens]
        else:
            input_ids = None

        inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
        return input_ids, inputs_embeds

    def _preprocess(
        self,
        scheduler_output: "SchedulerOutput",
        num_input_tokens: int,  # Padded
        intermediate_tensors: IntermediateTensors | None = None,
    ) -> tuple[
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor,
        IntermediateTensors | None,
        dict[str, Any],
        ECConnectorOutput | None,
    ]:
        """Build the model-forward inputs for this step.

        Three input paths exist on the first PP rank:
        - Multimodal (decoder-only): run the encoder, merge embeddings, and
          feed embeddings (plus raw ids if the model requires them).
        - Prompt embeds enabled: embed the token-id positions and feed
          embeddings only.
        - Otherwise: feed token ids only.

        Args:
            scheduler_output: The scheduler output for this step.
            num_input_tokens: Padded number of input tokens.
            intermediate_tensors: Tensors from the previous pipeline stage;
                required (asserted non-None) on non-first PP ranks.

        Returns:
            (input_ids, inputs_embeds, positions, intermediate_tensors,
            model_kwargs, ec_connector_output).
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        is_first_rank = get_pp_group().is_first_rank
        is_encoder_decoder = self.model_config.is_encoder_decoder

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        ec_connector_output = None

        if self.supports_mm_inputs and is_first_rank and not is_encoder_decoder:
            # Run the multimodal encoder if any.
            with self.maybe_get_ec_connector_output(
                scheduler_output,
                encoder_cache=self.encoder_cache,
            ) as ec_connector_output:
                self._execute_mm_encoder(scheduler_output)
                mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            inputs_embeds_scheduled = self.model.embed_input_ids(
                self.input_ids.gpu[:num_scheduled_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)

            input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens)
            model_kwargs = {
                **self._init_model_kwargs(),
                **self._extract_mm_kwargs(scheduler_output),
            }
        elif self.enable_prompt_embeds and is_first_rank:
            # Get the input embeddings for the tokens that are not input embeds,
            # then put them into the appropriate positions.
            # TODO(qthequartermasterman): Since even when prompt embeds are
            # enabled, (a) not all requests will use prompt embeds, and (b)
            # after the initial prompt is processed, the rest of the generated
            # tokens will be token ids, it is not desirable to have the
            # embedding layer outside of the CUDA graph all the time. The v0
            # engine avoids this by "double compiling" the CUDA graph, once
            # with input_ids and again with inputs_embeds, for all num_tokens.
            # If a batch only has token ids, then including the embedding layer
            # in the CUDA graph will be more performant (like in the else case
            # below).
            token_ids_idx = (
                self.is_token_ids.gpu[:num_scheduled_tokens]
                .nonzero(as_tuple=False)
                .squeeze(1)
            )
            # Some tokens ids may need to become embeds
            if token_ids_idx.numel() > 0:
                token_ids = self.input_ids.gpu[token_ids_idx]
                tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids)
                self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds

            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
            model_kwargs = self._init_model_kwargs()
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids.gpu[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = self._init_model_kwargs()

        positions = self._get_positions(num_input_tokens)

        if is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True
            )

        if is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
            # Run the encoder, just like we do with other multimodal inputs.
            # For an encoder-decoder model, our processing here is a bit
            # simpler, because the outputs are just passed to the decoder.
            # We are not doing any prompt replacement. We also will only
            # ever have a single encoder input.
            encoder_outputs = self._execute_mm_encoder(scheduler_output)
            model_kwargs.update({"encoder_outputs": encoder_outputs})

        return (
            input_ids,
            inputs_embeds,
            positions,
            intermediate_tensors,
            model_kwargs,
            ec_connector_output,
        )

    def _sample(
        self,
        logits: torch.Tensor | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> SamplerOutput:
        """Sample the next tokens (and logprobs if needed) from ``logits``.

        Without speculative decoding this defers to ``self.sampler``. With it,
        ``self.rejection_sampler`` is used, and draft probabilities come from
        ``self.draft_probs`` when that attribute is not None.
        """
        sampling_metadata = self.input_batch.sampling_metadata
        # Update output token ids with tokens sampled in last step
        # if async scheduling and required by current sampling params.
        self.input_batch.update_async_output_token_ids()
        if spec_decode_metadata is None:
            return self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )

        # Update spec_token_ids with real draft tokens from pre step only when
        # output_token_ids is needed (penalties or bad_words are in use).
        if self.use_async_scheduling and self._draft_token_req_ids is not None:
            draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu()
            self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)

        draft_probs = (
            None
            if self.draft_probs is None
            else self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids)
        )
        sampler_output = self.rejection_sampler(
            spec_decode_metadata,
            draft_probs,
            logits,
            sampling_metadata,
        )
        return sampler_output

    def _bookkeeping_sync(
        self,
        scheduler_output: "SchedulerOutput",
        sampler_output: SamplerOutput,
        logits: torch.Tensor | None,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> tuple[
        dict[str, int],
        LogprobsLists | None,
        list[list[int]],
        dict[str, LogprobsTensors | None],
        list[str],
        dict[str, int],
        list[int],
    ]:
        """Fold the tokens sampled in this step into CPU-side request state.

        Steps, in order:
        * optionally count NaNs in ``logits`` (``VLLM_COMPUTE_NANS_IN_LOGITS``);
        * rewind RNG generators of requests whose sample is discarded;
        * snapshot ``req_ids`` / ``req_id_to_index`` so later batch mutation
          (relevant with async scheduling) does not affect the returned values;
        * without async scheduling, materialize the sampled tokens (and
          logprobs) as Python lists; with async scheduling, keep them on the
          GPU and only record which request indices are invalid;
        * append sampled tokens to ``input_batch.token_ids_cpu`` and to each
          request's ``output_token_ids``;
        * compute prompt logprobs for the first ``num_scheduled_tokens``
          hidden states.

        Returns:
            ``(num_nans_in_logits, logprobs_lists, valid_sampled_token_ids,
            prompt_logprobs_dict, req_ids_copy, req_id_to_index_copy,
            invalid_req_indices)``.
        """
        nan_counts: dict[str, int] = (
            self._get_nans_in_logits(logits)
            if envs.VLLM_COMPUTE_NANS_IN_LOGITS
            else {}
        )

        batch = self.input_batch
        discarded = np.nonzero(self.discard_request_mask.np[: batch.num_reqs])[0]

        for req_index in discarded:
            generator = batch.generators.get(int(req_index))
            if generator is None:
                continue
            # NOTE(review): rewinds the generator offset by 4, presumably to
            # undo the RNG advance of the discarded sampling step; this relies
            # on torch generator implementation details — confirm.
            generator.set_offset(generator.get_offset() - 4)

        # Snapshot these before returning; async scheduling may mutate the
        # live batch afterwards.
        req_ids_snapshot = batch.req_ids.copy()
        req_id_to_index_snapshot = batch.req_id_to_index.copy()

        sampled = sampler_output.sampled_token_ids
        logprobs_tensors = sampler_output.logprobs_tensors
        num_sampled_rows = sampled.shape[0]
        async_mode = self.use_async_scheduling

        invalid_req_indices: list[int] = []
        logprobs_lists = None

        if async_mode:
            valid_sampled_token_ids = []
            invalid_req_indices = discarded.tolist()
            invalid_set = set(invalid_req_indices)

            # Keep sampled tokens on the GPU to avoid a CPU sync; they are
            # copied into input_ids when the next step prepares its inputs.
            # With spec decoding this happens in propose_draft_token_ids().
            if batch.prev_sampled_token_ids is None:
                assert sampled.shape[-1] == 1
                batch.prev_sampled_token_ids = sampled
            batch.prev_req_id_to_index = {
                rid: idx
                for idx, rid in enumerate(batch.req_ids)
                if idx not in invalid_set
            }
        elif sampled.shape[-1] == 1:
            # Plain sampling: one token per request, no spec decode tokens.
            valid_sampled_token_ids = self._to_list(sampled)
            for req_index in discarded:
                valid_sampled_token_ids[int(req_index)].clear()
            if logprobs_tensors is not None:
                logprobs_lists = logprobs_tensors.tolists()
        else:
            # Output includes spec decode tokens.
            valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
                sampled,
                batch.vocab_size,
                discarded,
                logprobs_tensors=logprobs_tensors,
            )

        # Cache the sampled tokens in the model runner so the scheduler does
        # not need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        live_req_ids = batch.req_ids
        for req_idx in range(num_sampled_rows):
            if async_mode:
                new_ids = None if req_idx in invalid_set else [-1]
            else:
                new_ids = valid_sampled_token_ids[req_idx]
            if not new_ids:
                continue

            start_idx = batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(new_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}"
            )

            batch.token_ids_cpu[req_idx, start_idx:end_idx] = new_ids
            batch.is_token_ids[req_idx, start_idx:end_idx] = True
            batch.num_tokens_no_spec[req_idx] = end_idx

            self.requests[live_req_ids[req_idx]].output_token_ids.extend(new_ids)

        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output.num_scheduled_tokens,
        )

        return (
            nan_counts,
            logprobs_lists,
            valid_sampled_token_ids,
            prompt_logprobs_dict,
            req_ids_snapshot,
            req_id_to_index_snapshot,
            invalid_req_indices,
        )

3439
3440
3441
3442
3443
3444
3445
3446
3447
3448
3449
3450
3451
3452
3453
    @contextmanager
    def synchronize_input_prep(self):
        """Guard reuse of CPU-side input tensors across steps.

        When ``self.prepare_inputs_event`` is set, wait for it before running
        the body (the previous step's async CPU->GPU copy must be finished with
        the reused CPU tensors) and record it again afterwards, even if the
        body raises. When it is ``None`` the body simply runs.
        """
        if self.prepare_inputs_event is not None:
            self.prepare_inputs_event.synchronize()
            try:
                yield
            finally:
                self.prepare_inputs_event.record()
        else:
            yield

3454
3455
    def _model_forward(
        self,
        input_ids: torch.Tensor | None = None,
        positions: torch.Tensor | None = None,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **model_kwargs: Any,
    ) -> Any:
        """Call ``self.model`` for one forward pass.

        Subclasses may override this to customize model execution. Keeping the
        bare model call in its own method lets it be inspected or patched in
        isolation from ``execute_model``, which carries additional logic.

        Args:
            input_ids: Input token IDs.
            positions: Token positions.
            intermediate_tensors: Tensors received from previous pipeline
                stages.
            inputs_embeds: Input embeddings, used instead of ``input_ids``.
            **model_kwargs: Extra keyword arguments forwarded unchanged to the
                model. Each value may be of any type; ``**kwargs: T`` annotates
                every individual value, so ``Any`` is used here rather than
                ``dict[str, Any]``.

        Returns:
            Whatever ``self.model`` returns.
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )

3486
3487
3488
3489
3490
3491
3492
3493
3494
3495
3496
3497
3498
3499
3500
3501
3502
3503
3504
3505
3506
    @staticmethod
    def _is_uniform_decode(
        max_num_scheduled_tokens: int,
        uniform_decode_query_len: int,
        num_tokens: int,
        num_reqs: int,
        force_uniform_decode: bool | None = None,
    ) -> bool:
        """
        Checks if it's a decode batch with same amount scheduled tokens
        across all requests.
        """
        return (
            (
                (max_num_scheduled_tokens == uniform_decode_query_len)
                and (num_tokens == max_num_scheduled_tokens * num_reqs)
            )
            if force_uniform_decode is None
            else force_uniform_decode
        )

3507
3508
3509
3510
3511
3512
3513
3514
3515
3516
3517
3518
3519
    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        num_reqs: int,
        num_scheduled_tokens_np: np.ndarray,
        max_num_scheduled_tokens: int,
        use_cascade_attn: bool,
        allow_microbatching: bool = True,
        force_eager: bool = False,
        # For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will
        # be improved in model runner v2)
        force_uniform_decode: bool | None = None,
        force_has_lora: bool | None = None,
        num_encoder_reqs: int = 0,
    ) -> tuple[
        CUDAGraphMode,
        BatchDescriptor,
        bool,
        torch.Tensor | None,
        CUDAGraphStat | None,
    ]:
        """Choose the CUDA graph mode and padded batch shape for a step.

        When data parallelism is enabled, this also coordinates the padded
        token count with the other DP ranks.

        Args:
            num_tokens: Number of unpadded tokens scheduled in this step.
            num_reqs: Number of requests in the batch.
            num_scheduled_tokens_np: Scheduled tokens per request. It is
                forwarded to ``coordinate_batch_across_dp``.
            max_num_scheduled_tokens: Largest per-request scheduled token count.
            use_cascade_attn: Whether cascade attention is used. If true, the
                initial dispatch is done with ``disable_full=True``.
            allow_microbatching: Forwarded to ``coordinate_batch_across_dp``.
            force_eager: Bypass the cudagraph dispatcher and use
                ``CUDAGraphMode.NONE``.
            force_uniform_decode: Overrides uniform-decode detection when not
                ``None``.
            force_has_lora: Overrides LoRA detection when not ``None``.
            num_encoder_reqs: Number of requests with scheduled encoder inputs.

        Returns:
            ``(cudagraph_mode, batch_descriptor, should_ubatch,
            num_tokens_across_dp, cudagraph_stats)``. ``cudagraph_stats`` is
            ``None`` unless cudagraph metrics are enabled.
        """
        uniform_decode = self._is_uniform_decode(
            max_num_scheduled_tokens=max_num_scheduled_tokens,
            uniform_decode_query_len=self.uniform_decode_query_len,
            num_tokens=num_tokens,
            num_reqs=num_reqs,
            force_uniform_decode=force_uniform_decode,
        )

        # Encoder-decoder models only support CUDA graphs for decoder steps > 0,
        # i.e. when no encoder output is present. Chunked prefill is disabled
        # for them, so their batches are uniform.
        has_encoder_output = (
            self.model_config.is_encoder_decoder and num_encoder_reqs > 0
        )

        has_lora = (
            len(self.input_batch.lora_id_to_lora_request) > 0
            if force_has_lora is None
            else force_has_lora
        )

        num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)

        def dispatch_cudagraph(padded_num_tokens: int, disable_full: bool):
            # Use the argument rather than the enclosing ``num_tokens_padded``
            # so the result never depends on late binding of that variable.
            if force_eager:
                return CUDAGraphMode.NONE, BatchDescriptor(padded_num_tokens)
            return self.cudagraph_dispatcher.dispatch(
                num_tokens=padded_num_tokens,
                has_lora=has_lora,
                uniform_decode=uniform_decode,
                disable_full=disable_full,
            )

        cudagraph_mode, batch_descriptor = dispatch_cudagraph(
            num_tokens_padded, use_cascade_attn or has_encoder_output
        )
        num_tokens_padded = batch_descriptor.num_tokens

        if self.compilation_config.pass_config.enable_sp:
            assert (
                batch_descriptor.num_tokens
                % self.vllm_config.parallel_config.tensor_parallel_size
                == 0
            ), (
                "Sequence parallelism requires num_tokens to be "
                "a multiple of tensor parallel size"
            )

        # Data-parallel runs need extra coordination, because the padded batch
        # must agree across ranks.
        should_ubatch, num_tokens_across_dp = False, None
        if self.vllm_config.parallel_config.data_parallel_size > 1:
            # Disable DP padding when running eager to avoid excessive padding when
            # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
            # in a P/D setup and still use CUDA graphs (enabled by this padding) on the
            # decoder.
            allow_dp_padding = (
                self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            )

            should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
                coordinate_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    parallel_config=self.parallel_config,
                    allow_microbatching=allow_microbatching,
                    allow_dp_padding=allow_dp_padding,
                    num_tokens_padded=num_tokens_padded,
                    uniform_decode=uniform_decode,
                    num_scheduled_tokens_per_request=num_scheduled_tokens_np,
                    cudagraph_mode=cudagraph_mode.value,
                )
            )

            # Extract DP-synced values
            if num_tokens_across_dp is not None:
                dp_rank = self.parallel_config.data_parallel_rank
                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
                # Re-dispatch with DP padding so we have the correct batch_descriptor
                cudagraph_mode, batch_descriptor = dispatch_cudagraph(
                    num_tokens_padded,
                    disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
                )
                # Assert to make sure the agreed upon token count is correct otherwise
                # num_tokens_across_dp will no-longer be valid
                assert batch_descriptor.num_tokens == num_tokens_padded

        cudagraph_stats = None
        if self.vllm_config.observability_config.cudagraph_metrics:
            cudagraph_stats = CUDAGraphStat(
                num_unpadded_tokens=num_tokens,
                num_padded_tokens=batch_descriptor.num_tokens,
                num_paddings=batch_descriptor.num_tokens - num_tokens,
                runtime_mode=str(cudagraph_mode),
            )

        return (
            cudagraph_mode,
            batch_descriptor,
            should_ubatch,
            num_tokens_across_dp,
            cudagraph_stats,
        )
3627

3628
3629
3630
3631
3632
3633
3634
3635
3636
3637
3638
3639
3640
3641
3642
3643
3644
3645
3646
3647
3648
3649
3650
3651
3652
3653
3654
3655
3656
3657
3658
3659
3660
3661
3662
    def _register_layerwise_nvtx_hooks(self) -> None:
        """
        Register layerwise NVTX hooks if --enable-layerwise-nvtx-tracing is enabled
        to trace detailed information of each layer or module in the model.
        """

        if (
            self.vllm_config.observability_config.enable_layerwise_nvtx_tracing
            and not self.layerwise_nvtx_hooks_registered
        ):
            if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                logger.debug_once(
                    "layerwise NVTX tracing is not supported when CUDA graph is "
                    "turned off; you may observe part or all of the model "
                    "missing NVTX markers"
                )

            # In STOCK_TORCH_COMPILE mode, after registering hooks here,
            # the __call__ function of nn.module will be recompiled with
            # fullgraph=True. Since nvtx.range_push/pop are not traceable
            # by torch dynamo, we can't register hook functions here
            # because hook functions will also be traced by torch dynamo.
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                logger.debug_once(
                    "layerwise NVTX tracing is not supported when "
                    "CompilationMode is STOCK_TORCH_COMPILE, skipping "
                    "function hooks registration"
                )
            else:
                pyt_hooks = PytHooks()
                pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
                self.layerwise_nvtx_hooks_registered = True
3663

3664
3665
3666
3667
3668
3669
3670
3671
3672
3673
3674
3675
3676
3677
3678
3679
3680
3681
3682
3683
3684
3685
3686
3687
3688
3689
3690
3691
3692
3693
3694
3695
3696
3697
3698
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708
3709
3710
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
3725
3726
3727
3728
3729
3730
3731
3732
3733
3734
3735
3736
3737
    def _get_slot_mappings(
        self,
        num_tokens_padded: int,
        num_reqs_padded: int,
        num_tokens_unpadded: int,
        ubatch_slices: "UBatchSlices | None" = None,
    ) -> tuple[
        dict[int, torch.Tensor] | None,
        dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
    ]:
        """
        Build slot mappings in both formats needed by the system.

        Args:
            num_tokens_padded: Total number of tokens (padded)
            num_reqs_padded: Total number of requests (padded)
            num_tokens_unpadded: Actual number of tokens (unpadded)
            ubatch_slices: Optional ubatch slicing info for DBO

        Returns:
            A tuple of:
            - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata
            - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext
        """
        if not (
            hasattr(self, "kv_cache_config")
            and self.kv_cache_config is not None
            and len(self.kv_cache_config.kv_cache_groups) > 0
        ):
            return None, None

        def _get_slot_mapping(kv_cache_gid: int):
            assert num_reqs_padded is not None and num_tokens_padded is not None
            kv_cache_spec = self.kv_cache_config.kv_cache_groups[
                kv_cache_gid
            ].kv_cache_spec
            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
                slot_mapping = torch.zeros(
                    (num_tokens_padded,),
                    dtype=torch.int64,
                    device=self.device,
                )
            else:
                blk_table = self.input_batch.block_table[kv_cache_gid]
                slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]

            # Fill unused with -1. Needed for reshape_and_cache in full cuda
            # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
            slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1)

            return slot_mapping

        slot_mappings_by_gid = {
            gid: _get_slot_mapping(gid)
            for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups)
        }

        slot_mappings_by_layer: dict[str, torch.Tensor] = {}
        for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups):
            slot_mapping = slot_mappings_by_gid[gid]
            for layer_name in kv_cache_group.layer_names:
                slot_mappings_by_layer[layer_name] = slot_mapping

        if ubatch_slices is not None:
            result: list[dict[str, torch.Tensor]] = []
            for ubatch in ubatch_slices:
                sliced_mappings: dict[str, torch.Tensor] = {}
                for layer_name, slot_mapping in slot_mappings_by_layer.items():
                    sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice]
                result.append(sliced_mappings)
            return slot_mappings_by_gid, result

        return slot_mappings_by_gid, slot_mappings_by_layer

3738
3739
3740
3741
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
3742
        intermediate_tensors: IntermediateTensors | None = None,
3743
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:
3744
3745
3746
3747
3748
        if self.execute_model_state is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )
3749

3750
3751
3752
3753
3754
3755
        if self.vllm_config.model_config.enable_return_routed_experts:
            capturer = RoutedExpertsCapturer.get_instance()
            if capturer is not None:
                capturer.clear_buffer()  # noqa
            else:
                logger.error("RoutedExpertsCapturer not initialized.")
3756

3757
3758
3759
3760
        if scheduler_output.preempted_req_ids and has_kv_transfer_group():
            get_kv_transfer_group().handle_preemptions(
                scheduler_output.preempted_req_ids
            )
3761

3762
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
3763
3764
3765
3766
3767
3768
        with (
            record_function_or_nullcontext("gpu_model_runner: preprocess"),
            self.synchronize_input_prep(),
        ):
            # Update persistent batch states.
            self._update_states(scheduler_output)
3769

3770
3771
            if has_ec_transfer() and get_ec_transfer().is_producer:
                with self.maybe_get_ec_connector_output(
3772
                    scheduler_output,
3773
3774
3775
3776
3777
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
3796
3797
3798
3799
3800
                    encoder_cache=self.encoder_cache,
                ) as ec_connector_output:
                    self._execute_mm_encoder(scheduler_output)
                    return make_empty_encoder_model_runner_output(scheduler_output)

            if not num_scheduled_tokens:
                if (
                    self.parallel_config.distributed_executor_backend
                    == "external_launcher"
                    and self.parallel_config.data_parallel_size > 1
                ):
                    # this is a corner case when both external launcher
                    # and DP are enabled, num_scheduled_tokens could be
                    # 0, and has_unfinished_requests in the outer loop
                    # returns True. before returning early here we call
                    # dummy run to ensure coordinate_batch_across_dp
                    # is called into to avoid out of sync issues.
                    self._dummy_run(1)
                if not has_kv_transfer_group():
                    # Return empty ModelRunnerOutput if no work to do.
                    return EMPTY_MODEL_RUNNER_OUTPUT
                return self.kv_connector_no_forward(scheduler_output, self.vllm_config)

            if self.cache_config.kv_sharing_fast_prefill:
                assert not self.num_prompt_logprobs, (
                    "--kv-sharing-fast-prefill produces incorrect "
                    "logprobs for prompt tokens, tokens, please disable "
                    "it when the requests need prompt logprobs"
3801
                )
3802

3803
3804
3805
3806
3807
3808
            num_reqs = self.input_batch.num_reqs
            req_ids = self.input_batch.req_ids
            tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
            num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
            max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
            num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
3809

3810
3811
3812
3813
            logits_indices, spec_decode_metadata = self._prepare_inputs(
                scheduler_output,
                num_scheduled_tokens_np,
            )
3814

3815
3816
3817
3818
3819
            cascade_attn_prefix_lens = None
            # Disable cascade attention when using microbatching (DBO)
            if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
                # Pre-compute cascade attention prefix lengths
                cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
3820
                    num_scheduled_tokens_np,
3821
3822
                    self.input_batch.num_computed_tokens_cpu[:num_reqs],
                    scheduler_output.num_common_prefix_blocks,
3823
                )
3824

3825
3826
3827
3828
3829
3830
3831
3832
3833
3834
3835
3836
3837
3838
            (
                cudagraph_mode,
                batch_desc,
                should_ubatch,
                num_tokens_across_dp,
                cudagraph_stats,
            ) = self._determine_batch_execution_and_padding(
                num_tokens=num_tokens_unpadded,
                num_reqs=num_reqs,
                num_scheduled_tokens_np=num_scheduled_tokens_np,
                max_num_scheduled_tokens=max_num_scheduled_tokens,
                use_cascade_attn=cascade_attn_prefix_lens is not None,
                num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
            )
3839

3840
3841
3842
3843
3844
3845
3846
3847
3848
3849
            logger.debug(
                "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
                "should_ubatch: %s, num_tokens_across_dp: %s",
                cudagraph_mode,
                batch_desc,
                should_ubatch,
                num_tokens_across_dp,
            )

            num_tokens_padded = batch_desc.num_tokens
3850
3851
            if self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould:
                num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
3852
3853
3854
3855
3856
3857
3858
3859
3860
3861
3862
3863
3864
3865
3866
3867
3868
            num_reqs_padded = (
                batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
            )
            ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
                should_ubatch,
                num_scheduled_tokens_np,
                num_tokens_padded,
                num_reqs_padded,
                self.parallel_config.num_ubatches,
            )

            logger.debug(
                "ubatch_slices: %s, ubatch_slices_padded: %s",
                ubatch_slices,
                ubatch_slices_padded,
            )

3869
3870
3871
3872
3873
3874
3875
3876
3877
3878
3879
            # True if any attention backend handles KV cache update separately
            # from forward() (i.e., forward_includes_kv_cache_update=False). When true,
            # slot_mappings must use padded dimensions to match the key/value tensors.
            has_separate_kv_update = not all(
                all(
                    g.backend.forward_includes_kv_cache_update
                    for g in self.attn_groups[id]
                )
                for id, spec in enumerate(self.kv_cache_config.kv_cache_groups)
                if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec)
            )
3880
3881
            pad_attn = cudagraph_mode == CUDAGraphMode.FULL

3882
3883
3884
3885
3886
3887
3888
3889
3890
3891
3892
3893
            if self.cache_config.mamba_cache_mode == "align":
                mamba_utils.preprocess_mamba(
                    scheduler_output,
                    self.kv_cache_config,
                    self.cache_config,
                    self.mamba_state_idx,
                    self.input_batch,
                    self.requests,
                    self.compilation_config.static_forward_context,
                    self.model.get_mamba_state_copy_func(),
                )

3894
3895
3896
            use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
            ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

3897
3898
3899
3900
3901
3902
3903
3904
3905
3906
3907
            slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
                num_tokens_padded=num_tokens_padded
                if pad_attn or has_separate_kv_update
                else num_tokens_unpadded,
                num_reqs_padded=(
                    num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs
                ),
                num_tokens_unpadded=num_tokens_unpadded,
                ubatch_slices=ubatch_slices_padded,
            )

3908
3909
3910
3911
3912
3913
            (
                attn_metadata,
                spec_decode_common_attn_metadata,
                scatter_indexes_tensor,
                gather_indexes_tensor,
            ) = self._build_attention_metadata(
3914
3915
3916
3917
3918
3919
3920
3921
3922
3923
                    num_tokens=num_tokens_unpadded,
                    num_tokens_padded=num_tokens_padded if pad_attn else None,
                    num_reqs=num_reqs,
                    num_reqs_padded=num_reqs_padded if pad_attn else None,
                    max_query_len=max_num_scheduled_tokens,
                    ubatch_slices=ubatch_slices_attn,
                    logits_indices=logits_indices,
                    use_spec_decode=use_spec_decode,
                    num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
                    cascade_attn_prefix_lens=cascade_attn_prefix_lens,
3924
                    slot_mappings=slot_mappings_by_group,
3925
                )
3926
3927
3928
3929
3930
3931
3932

            (
                input_ids,
                inputs_embeds,
                positions,
                intermediate_tensors,
                model_kwargs,
3933
3934
3935
                ec_connector_output,
            ) = self._preprocess(
                scheduler_output, num_tokens_padded, intermediate_tensors
3936
            )
3937

3938
        # Set cudagraph mode to none if calc_kv_scales is true.
3939
3940
3941
        # KV scales calculation involves dynamic operations that are incompatible
        # with CUDA graph capture.
        if self.calculate_kv_scales:
3942
            cudagraph_mode = CUDAGraphMode.NONE
3943
3944
            # Mark KV scales as calculated after the first forward pass
            self.calculate_kv_scales = False
3945

3946
3947
3948
3949
3950
3951
3952
        # Encoder-decoder models can only compile the pure decode steps where no
        # encoder inputs are present. Use eager for the first pass.
        num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs)
        has_encoder_input = (
            self.model_config.is_encoder_decoder and num_encoder_reqs > 0
        )

3953
3954
        # Run the model.
        # Use persistent buffers for CUDA graphs.
3955
        clear_kv_metadata = self.speculative_config is None
3956
3957
        with (
            set_forward_context(
3958
3959
                attn_metadata,
                self.vllm_config,
3960
                num_tokens=num_tokens_padded,
3961
                num_tokens_across_dp=num_tokens_across_dp,
3962
3963
                cudagraph_runtime_mode=cudagraph_mode,
                batch_descriptor=batch_desc,
3964
                ubatch_slices=ubatch_slices_padded,
3965
                slot_mapping=slot_mappings,
3966
                skip_compiled=has_encoder_input,
3967
3968
3969
3970
                scatter_indexes_tensor=scatter_indexes_tensor,
                gather_indexes_tensor=gather_indexes_tensor,
                enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
                enable_lightly_cplb=self.enable_lightly_cplb
3971
            ),
3972
            record_function_or_nullcontext("gpu_model_runner: forward"),
3973
3974
3975
            self.maybe_get_kv_connector_output(
                scheduler_output, clear_metadata=clear_kv_metadata
            ) as kv_connector_output,
3976
        ):
3977
            model_output = self._model_forward(
3978
3979
3980
3981
3982
3983
3984
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **model_kwargs,
            )

3985
        with record_function_or_nullcontext("gpu_model_runner: postprocess"):
3986
            if self.use_aux_hidden_state_outputs:
3987
                # True when EAGLE 3 is used.
3988
3989
                hidden_states, aux_hidden_states = model_output
            else:
3990
                # Common case.
3991
3992
3993
                hidden_states = model_output
                aux_hidden_states = None

3994
3995
3996
3997
3998
            if not self.broadcast_pp_output:
                # Common case.
                if not get_pp_group().is_last_rank:
                    # Return the intermediate tensors.
                    assert isinstance(hidden_states, IntermediateTensors)
3999
                    hidden_states.kv_connector_output = kv_connector_output
4000
                    self.kv_connector_output = kv_connector_output
4001
                    return hidden_states
4002

4003
                if self.is_pooling_model:
4004
                    # Return the pooling output.
4005
4006
4007
4008
4009
                    return self._pool(
                        hidden_states,
                        num_scheduled_tokens,
                        num_scheduled_tokens_np,
                        kv_connector_output,
4010
                    )
4011
4012

                sample_hidden_states = hidden_states[logits_indices]
4013
                logits = self.model.compute_logits(sample_hidden_states)
4014
4015
4016
4017
            else:
                # Rare case.
                assert not self.is_pooling_model

4018
                sample_hidden_states = hidden_states[logits_indices]
4019
                if not get_pp_group().is_last_rank:
4020
                    all_gather_tensors = {
4021
                        "residual": not is_residual_scattered_for_sp(
4022
                            self.vllm_config, num_tokens_padded
4023
                        )
4024
                    }
4025
                    get_pp_group().send_tensor_dict(
4026
4027
                        hidden_states.tensors,
                        all_gather_group=get_tp_group(),
4028
4029
                        all_gather_tensors=all_gather_tensors,
                    )
4030
4031
                    logits = None
                else:
4032
                    logits = self.model.compute_logits(sample_hidden_states)
4033

4034
                model_output_broadcast_data: dict[str, Any] = {}
4035
4036
4037
                if logits is not None:
                    model_output_broadcast_data["logits"] = logits.contiguous()

4038
                broadcasted = get_pp_group().broadcast_tensor_dict(
4039
4040
                    model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
                )
4041
4042
                assert broadcasted is not None
                logits = broadcasted["logits"]
4043

4044
4045
4046
4047
4048
4049
4050
4051
        self.execute_model_state = ExecuteModelState(
            scheduler_output,
            logits,
            spec_decode_metadata,
            spec_decode_common_attn_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
4052
            ec_connector_output,
4053
            cudagraph_stats,
4054
            slot_mappings,
4055
        )
4056
        self.kv_connector_output = kv_connector_output
4057
4058
4059
4060
4061
4062
        return None

    @torch.inference_mode
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
        """Sample tokens for the batch run by the preceding ``execute_model``.

        Consumes and clears the ephemeral ``execute_model_state`` and
        ``kv_connector_output`` stashed by ``execute_model``. It then:

        - applies the structured-output bitmask, if any;
        - samples and updates per-request state;
        - optionally proposes speculative draft tokens;
        - runs the EPLB step;
        - packages a ``ModelRunnerOutput``, wrapped in an
          ``AsyncGPUModelRunnerOutput`` when async scheduling is enabled.

        Args:
            grammar_output: Structured-output bitmask data for this step, or
                ``None`` when no bitmask has to be applied.

        Returns:
            ``None`` or an (almost) empty output when there is no stashed
            model state (the PP non-final rank case). Otherwise the runner
            output for this step.
        """
        kv_connector_output = self.kv_connector_output
        self.kv_connector_output = None

        if self.execute_model_state is None:
            # Nothing to do (PP non-final rank case), output isn't used.
            if not kv_connector_output:
                return None  # type: ignore[return-value]

            # In case of PP with kv transfer, we need to pass through the
            # kv_connector_output
            if kv_connector_output.is_empty():
                return EMPTY_MODEL_RUNNER_OUTPUT

            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
            output.kv_connector_output = kv_connector_output
            return output

        # Unpack ephemeral state.
        (
            scheduler_output,
            logits,
            spec_decode_metadata,
            spec_decode_common_attn_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            ec_connector_output,
            cudagraph_stats,
            slot_mappings,
        ) = self.execute_model_state
        # Clear ephemeral state.
        self.execute_model_state = None

        # Apply structured output bitmasks if present.
        if grammar_output is not None:
            apply_grammar_bitmask(
                scheduler_output, grammar_output, self.input_batch, logits
            )

        with record_function_or_nullcontext("gpu_model_runner: sample"):
            sampler_output = self._sample(logits, spec_decode_metadata)

        self._update_states_after_model_execute(
            sampler_output.sampled_token_ids, scheduler_output
        )

        self._draft_token_ids = None
        self._draft_token_req_ids = None
        self.input_batch.prev_sampled_token_ids = None

        def _run_drafter(sampled_token_ids):
            # Propose drafts and kick off their device->host copy. Named
            # differently from the ``propose_draft_token_ids`` method so the
            # closure does not shadow it.
            assert spec_decode_common_attn_metadata is not None
            with record_function_or_nullcontext("gpu_model_runner: draft"):
                self._draft_token_ids = self.propose_draft_token_ids(
                    scheduler_output,
                    sampled_token_ids,
                    self.input_batch.sampling_metadata,
                    hidden_states,
                    sample_hidden_states,
                    aux_hidden_states,
                    spec_decode_metadata,
                    spec_decode_common_attn_metadata,
                    slot_mappings,
                )
                self._copy_draft_token_ids_to_cpu(scheduler_output)

        spec_config = self.speculative_config
        propose_drafts_after_bookkeeping = False
        if spec_config is not None:
            input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
                spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
                <= self.effective_drafter_max_model_len
            )
            use_gpu_toks = (
                spec_config.use_eagle() or spec_config.uses_draft_model()
            ) and not spec_config.disable_padded_drafter_batch
            if use_gpu_toks:
                # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
                # as inputs, and does not need to wait for bookkeeping to finish.
                assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
                sampled_token_ids = sampler_output.sampled_token_ids
                if input_fits_in_drafter:
                    _run_drafter(sampled_token_ids)
                elif self.valid_sampled_token_count_event is not None:
                    assert spec_decode_common_attn_metadata is not None
                    next_token_ids, valid_sampled_tokens_count = (
                        self.drafter.prepare_next_token_ids_padded(
                            spec_decode_common_attn_metadata,
                            sampled_token_ids,
                            self.requests,
                            self.input_batch,
                            self.discard_request_mask.gpu,
                        )
                    )
                    self._copy_valid_sampled_token_count(
                        next_token_ids, valid_sampled_tokens_count
                    )
                    # Since we couldn't run the drafter,
                    # just use zeros for the draft tokens.
                    self._draft_token_ids = torch.zeros(
                        1, device=self.device, dtype=torch.int32
                    ).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
                    self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
            else:
                propose_drafts_after_bookkeeping = input_fits_in_drafter

        with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
            (
                num_nans_in_logits,
                logprobs_lists,
                valid_sampled_token_ids,
                prompt_logprobs_dict,
                req_ids_output_copy,
                req_id_to_index_output_copy,
                invalid_req_indices,
            ) = self._bookkeeping_sync(
                scheduler_output,
                sampler_output,
                logits,
                hidden_states,
                scheduler_output.total_num_scheduled_tokens,
                spec_decode_metadata,
            )

        if propose_drafts_after_bookkeeping:
            # ngram and other speculative decoding methods use the sampled
            # tokens on the CPU, so they are run after bookkeeping.
            _run_drafter(valid_sampled_token_ids)

        if self.speculative_config is not None:
            self.clear_kv_connector_metadata()

        with record_function_or_nullcontext("gpu_model_runner: eplb"):
            self.eplb_step()

        with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
            # Get draft token ids if available
            output_spec_token_ids = None
            if not self.use_async_scheduling and self._draft_token_ids is not None:
                # Use synchronous copy to avoid NPU async stream/event
                # synchronization issues. _get_draft_token_ids_cpu relies on
                # event.synchronize() which may not properly wait for the
                # async copy on NPU, resulting in stale data.
                if torch.is_tensor(self._draft_token_ids):
                    draft_ids_list = self._draft_token_ids.cpu().tolist()
                    draft_req_ids = self._draft_token_req_ids
                else:
                    draft_ids_list = self._draft_token_ids
                    draft_req_ids = self.input_batch.req_ids
                if draft_ids_list and draft_req_ids:
                    # Re-key drafts by request id so they line up with the
                    # output request order; missing requests get no drafts.
                    draft_by_req_id = dict(zip(draft_req_ids, draft_ids_list))
                    output_spec_token_ids = [
                        draft_by_req_id.get(req_id, [])
                        for req_id in req_ids_output_copy
                    ]

            if self.model_config.enable_return_routed_experts:
                capturer = RoutedExpertsCapturer.get_instance()
                if capturer is not None:
                    capturer.save_captured_experts(indices=self.slot_mapping)  # noqa
                else:
                    logger.error("RoutedExpertsCapturer not initialized.")

            output = ModelRunnerOutput(
                req_ids=req_ids_output_copy,
                req_id_to_index=req_id_to_index_output_copy,
                sampled_token_ids=valid_sampled_token_ids,
                spec_token_ids=output_spec_token_ids,
                logprobs=logprobs_lists,
                prompt_logprobs_dict=prompt_logprobs_dict,
                kv_connector_output=kv_connector_output,
                ec_connector_output=ec_connector_output
                if self.supports_mm_inputs
                else None,
                num_nans_in_logits=num_nans_in_logits,
                cudagraph_stats=cudagraph_stats,
            )

        if not self.use_async_scheduling:
            return output

        with record_function_or_nullcontext(
            "gpu_model_runner: AsyncGPUModelRunnerOutput"
        ):
            async_output = AsyncGPUModelRunnerOutput(
                model_runner_output=output,
                sampled_token_ids=sampler_output.sampled_token_ids,
                logprobs_tensors=sampler_output.logprobs_tensors,
                invalid_req_indices=invalid_req_indices,
                async_output_copy_stream=self.async_output_copy_stream,
                vocab_size=self.input_batch.vocab_size,
            )
        with record_function_or_nullcontext(
            "gpu_model_runner: set_async_sampled_token_ids"
        ):
            # Save ref of sampled_token_ids CPU tensor if the batch contains
            # any requests with sampling params that require output ids.
            self.input_batch.set_async_sampled_token_ids(
                async_output.sampled_token_ids_cpu,
                async_output.async_copy_ready_event,
            )

        return async_output

4269
    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the most recently proposed draft tokens, if any.

        Returns ``None`` when speculative decoding is disabled
        (``num_spec_tokens`` is falsy) or when no draft request ids were
        recorded for the last step. Otherwise returns the host-side drafts
        together with the request ids they belong to.
        """
        if not self.num_spec_tokens or not self._draft_token_req_ids:
            return None
        draft_token_ids, req_ids = self._get_draft_token_ids_cpu()
        return DraftTokenIds(req_ids, draft_token_ids)

4275
4276
4277
    def _copy_draft_token_ids_to_cpu(
        self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
    ) -> None:
        """Record the draft request ids and start a device->host draft copy.

        The copy runs on ``draft_token_ids_copy_stream`` and completion is
        signalled through ``draft_token_ids_event``. That event is what
        ``_get_draft_token_ids_cpu`` synchronizes on.

        Args:
            scheduler_output: Output of the scheduler for this step. Used to
                decide whether a copy is needed at all under async scheduling.
            zeros_only: When True the drafter could not run. The host buffer
                is zeroed instead of copying from the device.
        """
        # Check if we need to copy draft tokens to CPU. In async scheduling,
        # we only copy when needed for structured output, penalties or bad_words.
        if self.use_async_scheduling and not (
            scheduler_output.has_structured_output_requests
            or self.input_batch.sampling_metadata.output_token_ids
        ):
            return
        # We must also set the corresponding request ids.
        self._draft_token_req_ids = self.input_batch.req_ids.copy()

        # May be a tensor or a plain list (propose_draft_token_ids returns
        # either); only tensors need a device->host copy.
        draft_token_ids = self._draft_token_ids
        if not torch.is_tensor(draft_token_ids):
            return
        assert self.draft_token_ids_event is not None
        assert self.draft_token_ids_copy_stream is not None
        assert self.draft_token_ids_cpu is not None
        default_stream = torch.cuda.current_stream()
        num_reqs = draft_token_ids.shape[0]
        with torch.cuda.stream(self.draft_token_ids_copy_stream):
            if not zeros_only:
                # Trigger async copy of draft token ids to cpu. Wait for the
                # compute stream first so the drafts are fully written.
                self.draft_token_ids_copy_stream.wait_stream(default_stream)
                self.draft_token_ids_cpu[:num_reqs].copy_(
                    draft_token_ids, non_blocking=True
                )
            else:
                # No copy needed, just zero-out cpu tensor.
                self.draft_token_ids_cpu[:num_reqs] = 0
            self.draft_token_ids_event.record()

4308
    def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]:
        """Return the latest drafts on the host together with their request ids.

        List-valued drafts are returned as-is, paired with the current batch's
        request ids. Tensor-valued drafts are read from ``draft_token_ids_cpu``
        after waiting on ``draft_token_ids_event``. Returns ``([], [])`` when
        no draft request ids were recorded.
        """
        if isinstance(self._draft_token_ids, list):
            return self._draft_token_ids, self.input_batch.req_ids
        req_ids = self._draft_token_req_ids
        if req_ids is None:
            return [], []
        assert self.draft_token_ids_event is not None
        assert self.draft_token_ids_cpu is not None
        # Block until the copy recorded by _copy_draft_token_ids_to_cpu
        # has finished writing the host buffer.
        self.draft_token_ids_event.synchronize()
        return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids

4319
4320
4321
4322
4323
4324
4325
4326
4327
4328
4329
4330
4331
    def _copy_valid_sampled_token_count(
        self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
    ) -> None:
        """Start copying valid sampled-token counts to the host.

        Also stores ``next_token_ids.unsqueeze(1)`` as the input batch's
        ``prev_sampled_token_ids``. Does nothing at all, including leaving
        ``prev_sampled_token_ids`` untouched, when
        ``valid_sampled_token_count_event`` is ``None``.

        Args:
            next_token_ids: Device tensor of next token ids, presumably 1-D
                with one entry per request (TODO confirm).
            valid_sampled_tokens_count: Device tensor of per-request valid
                sampled-token counts.
        """
        if self.valid_sampled_token_count_event is None:
            return

        default_stream = torch.cuda.current_stream()
        # Initialize a new stream to overlap the copy operation with
        # prepare_input of draft model.
        with torch.cuda.stream(self.valid_sampled_token_count_copy_stream):
            self.valid_sampled_token_count_copy_stream.wait_stream(default_stream)  # type: ignore
            counts = valid_sampled_tokens_count
            counts_cpu = self.valid_sampled_token_count_cpu
            assert counts_cpu is not None
            counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
            # _get_valid_sampled_token_count synchronizes on this event.
            self.valid_sampled_token_count_event.record()

        self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)

    def _get_valid_sampled_token_count(self) -> list[int]:
        """Return the host copy of the per-request valid sampled-token counts.

        Blocks on ``valid_sampled_token_count_event`` before reading. Returns
        ``[]`` when there is no event or no ``prev_sampled_token_ids``. Only
        the first ``prev_sampled_token_ids.shape[0]`` counts are returned.
        """
        # Wait until valid_sampled_tokens_count is copied to cpu,
        prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
        sampled_count_event = self.valid_sampled_token_count_event
        if sampled_count_event is None or prev_sampled_token_ids is None:
            return []

        counts_cpu = self.valid_sampled_token_count_cpu
        assert counts_cpu is not None
        sampled_count_event.synchronize()
        return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()

4350
4351
4352
    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: torch.Tensor | list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
    ) -> list[list[int]] | torch.Tensor:
        """Propose speculative draft tokens for the current batch.

        Dispatches on ``speculative_config.method``:

        - ``ngram`` and ``suffix`` consume the CPU-side list of sampled
          tokens.
        - ``medusa`` consumes the target model's hidden states.
        - EAGLE / draft-model proposers consume the target token ids,
          positions and hidden states, prepared either padded or unpadded
          depending on ``disable_padded_drafter_batch``.

        When ``envs.VLLM_REJECT_SAMPLE_OPT`` is set, the EAGLE/draft-model
        drafter returns ``(draft_token_ids, draft_probs)`` and the
        probabilities are stored on ``self.draft_probs``.

        Returns:
            The proposed draft token ids: a per-request list, or a tensor.

        Raises:
            ValueError: If the configured speculative method is not handled.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        spec_config = self.speculative_config
        assert spec_config is not None
        if spec_config.method == "ngram":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, NgramProposer)
            draft_token_ids = self.drafter.propose(
                sampled_token_ids,
                self.input_batch.num_tokens_no_spec,
                self.input_batch.token_ids_cpu,
                slot_mappings=slot_mappings,
            )
        elif spec_config.method == "suffix":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, SuffixDecodingProposer)
            draft_token_ids = self.drafter.propose(
                self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
            )
        elif spec_config.method == "medusa":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, MedusaProposer)

            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                assert spec_decode_metadata is not None, (
                    "No spec decode metadata for medusa"
                )
                # Each request occupies num_draft + 1 rows of
                # sample_hidden_states; pick row offset + len(tokens) - 1.
                row_indices = []
                offset = 0
                for num_draft, tokens in zip(
                    spec_decode_metadata.num_draft_tokens, sampled_token_ids
                ):
                    row_indices.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(row_indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            draft_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
                slot_mappings=slot_mappings,
            )
        elif spec_config.use_eagle() or spec_config.uses_draft_model():
            assert isinstance(self.drafter, EagleProposer | DraftModelProposer)

            if spec_config.disable_padded_drafter_batch:
                # When padded-batch is disabled, the sampled_token_ids should be
                # the cpu-side list[list[int]] of valid sampled tokens for each
                # request, with invalid requests having empty lists.
                assert isinstance(sampled_token_ids, list), (
                    "sampled_token_ids should be a python list when "
                    "padded-batch is disabled."
                )
                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
                    sampled_token_ids,
                    self.requests,
                    self.input_batch,
                    scheduler_output.num_scheduled_tokens,
                )
            else:
                # When using padded-batch, the sampled_token_ids should be
                # the gpu tensor of sampled tokens for each request, of shape
                # (num_reqs, num_spec_tokens + 1) with rejected tokens having
                # value -1.
                assert isinstance(sampled_token_ids, torch.Tensor), (
                    "sampled_token_ids should be a torch.Tensor when "
                    "padded-batch is enabled."
                )
                next_token_ids, valid_sampled_tokens_count = (
                    self.drafter.prepare_next_token_ids_padded(
                        common_attn_metadata,
                        sampled_token_ids,
                        self.requests,
                        self.input_batch,
                        self.discard_request_mask.gpu,
                    )
                )
                self._copy_valid_sampled_token_count(
                    next_token_ids, valid_sampled_tokens_count
                )

            num_rejected_tokens_gpu = None
            if spec_decode_metadata is None:
                token_indices_to_sample = None
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
                target_positions = self._get_positions(num_scheduled_tokens)
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
                    )
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
            elif spec_config.disable_padded_drafter_batch:
                token_indices_to_sample = None
                common_attn_metadata, token_indices = self.drafter.prepare_inputs(
                    common_attn_metadata,
                    sampled_token_ids,
                    spec_decode_metadata.num_draft_tokens,
                )
                target_token_ids = self.input_ids.gpu[token_indices]
                target_positions = self._get_positions(token_indices)
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1
                    )
                else:
                    target_hidden_states = hidden_states[token_indices]
            else:
                (
                    common_attn_metadata,
                    token_indices_to_sample,
                    num_rejected_tokens_gpu,
                ) = self.drafter.prepare_inputs_padded(
                    common_attn_metadata,
                    spec_decode_metadata,
                    valid_sampled_tokens_count,
                )
                # With lightly context parallelism, size the target slices by
                # the CP metadata's token count when it is present.
                if (
                    self.enable_lightly_cp
                    and common_attn_metadata.cp_common_metadata is not None
                ):
                    total_num_tokens = (
                        common_attn_metadata.cp_common_metadata.num_actual_tokens
                    )
                else:
                    total_num_tokens = common_attn_metadata.num_actual_tokens
                # When padding the batch, token_indices is just a range
                target_token_ids = self.input_ids.gpu[:total_num_tokens]
                target_positions = self._get_positions(total_num_tokens)
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[:total_num_tokens] for h in aux_hidden_states], dim=-1
                    )
                else:
                    target_hidden_states = hidden_states[:total_num_tokens]

            if self.supports_mm_inputs:
                mm_embed_inputs = self._gather_mm_embeddings(
                    scheduler_output,
                    shift_computed_tokens=1,
                )
            else:
                mm_embed_inputs = None

            draft_result = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                next_token_ids=next_token_ids,
                last_token_indices=token_indices_to_sample,
                sampling_metadata=sampling_metadata,
                common_attn_metadata=common_attn_metadata,
                mm_embed_inputs=mm_embed_inputs,
                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
                slot_mappings=slot_mappings,
            )

            if envs.VLLM_REJECT_SAMPLE_OPT:
                # The drafter returns (draft_token_ids, draft_probs) in this mode.
                draft_token_ids, draft_probs = draft_result
                # NOTE(review): draft_probs rows presumably follow the
                # input_batch order, while these ids follow the key order of
                # scheduler_output.num_scheduled_tokens. Confirm that the two
                # orders agree (e.g. after batch reordering) or that DraftProbs
                # expects scheduler order.
                draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
                if self.draft_probs is None:
                    self.draft_probs = DraftProbs(draft_probs, draft_req_ids)
                else:
                    self.draft_probs.update(draft_probs, draft_req_ids)
            else:
                draft_token_ids = draft_result
        else:
            # Previously this fell through to an UnboundLocalError on
            # draft_token_ids; fail with an explicit message instead.
            raise ValueError(
                f"Unsupported speculative decoding method: {spec_config.method}"
            )

        return draft_token_ids

4542
4543
4544
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Apply per-config overrides to this runner's config objects.

        Args:
            overrides: Maps a config attribute name to the overrides for that
                config. Only ``load_config`` and ``model_config`` are allowed.
                Each entry is merged with the module-level ``update_config``
                helper and written back onto ``self``.

        Raises:
            AssertionError: If an override targets a config name that is not
                allowed.
        """
        allowed_config_names = {"load_config", "model_config"}
        for config_name, config_overrides in overrides.items():
            assert config_name in allowed_config_names, (
                f"Config `{config_name}` not supported. "
                f"Allowed configs: {allowed_config_names}"
            )
            # Resolves to the module-level helper (imported from vllm.config),
            # not to this method.
            config = getattr(self, config_name)
            new_config = update_config(config, config_overrides)
            setattr(self, config_name, new_config)

4553
4554
4555
4556
4557
    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the target model (and the drafter, if any) onto the device.

        Besides loading weights this method:
          - wraps the model for LoRA when ``self.lora_config`` is set,
          - loads the drafter and validates EAGLE3 aux-hidden-state support,
          - registers MoE models with the EPLB state when EPLB is enabled
            (drafter first, then the target model),
          - prepares communication buffers for the model(s),
          - compiles (STOCK_TORCH_COMPILE) or wraps the model for cudagraphs /
            micro-batching according to the compilation config.

        Args:
            eep_scale_up: the model loading is for elastic EP scale up.

        Raises:
            torch.cuda.OutOfMemoryError: re-raised when the weights do not fit;
                a hint on how to reduce memory usage is logged first.
            RuntimeError: if aux hidden state outputs were requested but the
                model does not implement the EAGLE3 interface.
        """
        logger.info_once(
            "Starting to load model %s...",
            self.model_config.model,
            scope="global",
        )
        # Expert state carried over for elastic EP scale up; all None otherwise.
        global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
            EplbState.get_eep_state(self.parallel_config)
            if eep_scale_up
            else (None, None, None)
        )

        def _eep_state_for(model_idx: int):
            """Return (global_expert_load, old_global_expert_indices) for the
            ``model_idx``-th EPLB-registered model, or Nones when absent."""
            global_expert_load = (
                global_expert_loads[model_idx] if global_expert_loads else None
            )
            old_global_expert_indices = (
                old_global_expert_indices_per_model[model_idx]
                if old_global_expert_indices_per_model
                else None
            )
            return global_expert_load, old_global_expert_indices

        # Number of models registered with EPLB so far; also the index into the
        # per-model scale-up state above. Always bound so later uses are safe.
        eplb_models = 0
        if self.parallel_config.enable_eplb:
            self.eplb_state = EplbState(self.parallel_config, self.device)

        try:
            with DeviceMemoryProfiler() as m:
                time_before_load = time.perf_counter()
                model_loader = get_model_loader(self.load_config)
                self.model = model_loader.load_model(
                    vllm_config=self.vllm_config, model_config=self.model_config
                )
                if self.lora_config:
                    self.model = self.load_lora_model(
                        self.model, self.vllm_config, self.device
                    )
                if hasattr(self, "drafter"):
                    logger.info_once("Loading drafter model...")
                    self.drafter.load_model(self.model)
                    if (
                        hasattr(self.drafter, "model")
                        and is_mixture_of_experts(self.drafter.model)
                        and self.parallel_config.enable_eplb
                    ):
                        spec_config = self.vllm_config.speculative_config
                        assert spec_config is not None
                        assert spec_config.draft_model_config is not None
                        logger.info_once(
                            "EPLB is enabled for drafter model %s.",
                            spec_config.draft_model_config.model,
                        )
                        global_expert_load, old_global_expert_indices = (
                            _eep_state_for(eplb_models)
                        )
                        if self.eplb_state is None:
                            self.eplb_state = EplbState(
                                self.parallel_config, self.device
                            )
                        self.eplb_state.add_model(
                            self.drafter.model,
                            spec_config.draft_model_config,
                            global_expert_load,
                            old_global_expert_indices,
                            rank_mapping,
                        )
                        eplb_models += 1

                if self.use_aux_hidden_state_outputs:
                    if not supports_eagle3(self.get_model()):
                        raise RuntimeError(
                            "Model does not support EAGLE3 interface but "
                            "aux_hidden_state_outputs was requested"
                        )

                    # Try to get auxiliary layers from speculative config,
                    # otherwise use model's default layers
                    aux_layers = self._get_eagle3_aux_layers_from_config()
                    if aux_layers:
                        logger.info(
                            "Using auxiliary layers from speculative config: %s",
                            aux_layers,
                        )
                    else:
                        aux_layers = self.model.get_eagle3_aux_hidden_state_layers()

                    self.model.set_aux_hidden_state_layers(aux_layers)
                time_after_load = time.perf_counter()
            self.model_memory_usage = m.consumed_memory
        except torch.cuda.OutOfMemoryError as e:
            msg = (
                "Failed to load model - not enough GPU memory. "
                "Try lowering --gpu-memory-utilization to free memory for weights, "
                "increasing --tensor-parallel-size, or using --quantization. "
                "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
                "for more tips."
            )
            combined_msg = f"{msg} (original error: {e})"
            logger.error(combined_msg)
            # Bare re-raise keeps the original traceback untouched.
            raise
        logger.info_once(
            "Model loading took %s GiB memory and %.6f seconds",
            format_gib(self.model_memory_usage),
            time_after_load - time_before_load,
            scope="local",
        )
        prepare_communication_buffer_for_model(self.model)
        if (drafter := getattr(self, "drafter", None)) and (
            drafter_model := getattr(drafter, "model", None)
        ):
            prepare_communication_buffer_for_model(drafter_model)
        mm_config = self.model_config.multimodal_config
        self.is_multimodal_pruning_enabled = (
            supports_multimodal_pruning(self.get_model())
            and mm_config is not None
            and mm_config.is_multimodal_pruning_enabled()
        )

        if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
            logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
            global_expert_load, old_global_expert_indices = _eep_state_for(
                eplb_models
            )
            assert self.eplb_state is not None
            self.eplb_state.add_model(
                self.model,
                self.model_config,
                global_expert_load,
                old_global_expert_indices,
                rank_mapping,
            )
            if self.eplb_state.is_async:
                self.eplb_state.start_async_loop(rank_mapping=rank_mapping)

        if (
            self.vllm_config.compilation_config.mode
            == CompilationMode.STOCK_TORCH_COMPILE
        ):
            backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
            compilation_counter.stock_torch_compile_count += 1
            self.model.compile(fullgraph=True, backend=backend)
            return
        # For other compilation modes, cudagraph behavior is controlled by
        # CudagraphWraper and CudagraphDispatcher of vllm.

        # Wrap the model with the full cudagraph wrapper if needed.
        cudagraph_mode = self.compilation_config.cudagraph_mode
        assert cudagraph_mode is not None
        if self.parallel_config.use_ubatching:
            # UBatchWrapper gets FULL when the configured mode has full
            # cudagraphs and NONE otherwise.
            ubatch_cudagraph_mode = (
                CUDAGraphMode.FULL
                if cudagraph_mode.has_full_cudagraphs()
                else CUDAGraphMode.NONE
            )
            self.model = UBatchWrapper(
                self.model, self.vllm_config, ubatch_cudagraph_mode, self.device
            )
        elif cudagraph_mode.has_full_cudagraphs():
            self.model = CUDAGraphWrapper(
                self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
            )
4724

4725
    def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
4726
4727
4728
4729
4730
4731
4732
4733
4734
4735
4736
4737
4738
4739
4740
4741
4742
4743
4744
4745
4746
4747
        """Extract Eagle3 auxiliary layer indices from speculative config.

        These indices specify which hidden states from the base model should
        be used as auxiliary inputs for the Eagle3 drafter model during
        speculative decoding.

        Returns:
            Tuple of layer indices if found in draft model config,
            None otherwise.
        """
        if not (self.speculative_config and self.speculative_config.draft_model_config):
            return None

        hf_config = self.speculative_config.draft_model_config.hf_config
        if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
            return None

        layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
        if layer_ids and isinstance(layer_ids, (list, tuple)):
            return tuple(layer_ids)

        return None
4748

4749
    def reload_weights(self) -> None:
        """Reload the already-loaded model's weights in place.

        A fresh loader is obtained from ``self.load_config`` and its
        ``load_weights`` is applied to ``self.get_model()``.

        Raises:
            AssertionError: (when assertions are enabled) if no model has been
                loaded yet.
        """
        loaded_model = getattr(self, "model", None)
        assert loaded_model is not None, (
            "Cannot reload weights before model is loaded."
        )
        loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        loader.load_weights(self.get_model(), model_config=self.model_config)
4756

4757
4758
4759
4760
4761
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the current model with Tensorizer.

        Args:
            tensorizer_config: Tensorizer settings, forwarded unchanged to
                ``TensorizerLoader.save_model`` together with the model config.
        """
        save_model = TensorizerLoader.save_model
        save_model(
            self.get_model(),
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )

4767
4768
4769
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: dict[str, int],
    ) -> dict[str, LogprobsTensors | None]:
        """Compute prompt logprobs for requests that requested them.

        For each request in ``self.num_prompt_logprobs`` that is scheduled this
        step, logits are computed for its scheduled prompt positions and the
        logprobs / ranks of the following prompt token (plus top-k) are copied
        into a CPU ``LogprobsTensors`` that accumulates across prefill chunks
        in ``self.input_batch.in_progress_prompt_logprobs_cpu``.

        Args:
            hidden_states: hidden states for all tokens scheduled this step;
                a request's rows start at ``self.query_start_loc.np[req_idx]``.
            num_scheduled_tokens: request id -> tokens scheduled this step.

        Returns:
            Request id -> completed prompt ``LogprobsTensors``, containing only
            requests whose prompt logprobs finished this step. Empty when no
            request asked for prompt logprobs.
        """
        num_prompt_logprobs_dict = self.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
            num_tokens = num_scheduled_tokens.get(req_id)
            if num_tokens is None:
                # This can happen if the request was preempted in prefill stage.
                continue

            # Get metadata for this request.
            request = self.requests[req_id]
            if request.prompt_token_ids is None:
                # Prompt logprobs is incompatible with prompt embeddings
                continue

            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if logprobs_tensors is None:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1
                )
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc.np[req_idx].item()
            prompt_hidden_states = hidden_states[offset : offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only this chunk's slice is moved to
            # the device (not the whole prompt), and only when it is needed.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok : start_tok + num_logits]
            ).to(self.device, non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids
            )

            # Transfer GPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True
            )
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True
            )

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        if prompt_logprobs_dict:
            self._sync_device()

        return prompt_logprobs_dict

4872
4873
    def _get_nans_in_logits(
        self,
4874
        logits: torch.Tensor | None,
4875
4876
4877
4878
4879
4880
4881
4882
4883
4884
4885
    ) -> dict[str, int]:
        try:
            if logits is None:
                return {req_id: 0 for req_id in self.input_batch.req_ids}

            num_nans_in_logits = {}
            num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
            for req_id in self.input_batch.req_ids:
                req_index = self.input_batch.req_id_to_index[req_id]
                num_nans_in_logits[req_id] = (
                    int(num_nans_for_index[req_index])
4886
4887
4888
                    if num_nans_for_index is not None and req_index < logits.shape[0]
                    else 0
                )
4889
4890
4891
4892
            return num_nans_in_logits
        except IndexError:
            return {}

4893
    @contextmanager
    def maybe_randomize_inputs(
        self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None
    ):
        """
        Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
        This is to help balance expert-selection
         - during profile_run
         - during DP rank dummy run

        Randomization only happens when data_parallel_size > 1. If
        ``input_ids`` is given it is filled with random token ids in
        ``[0, vocab_size)``; otherwise ``inputs_embeds`` is filled with
        standard-normal values. On exit (including when the body raises) the
        randomized tensor is reset to zeros.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
        if not randomize_inputs:
            yield
            return

        if input_ids is not None:
            target = input_ids
            logger.debug_once("Randomizing dummy input_ids for DP Rank")
            # Sized to the slice actually used rather than the whole buffer.
            random_values = torch.randint_like(
                input_ids,
                low=0,
                high=self.model_config.get_vocab_size(),
            )
        else:
            assert inputs_embeds is not None
            target = inputs_embeds
            logger.debug_once("Randomizing dummy inputs_embeds for DP Rank")
            random_values = torch.randn_like(inputs_embeds)

        target.copy_(random_values, non_blocking=True)
        try:
            yield
        finally:
            # Restore zeros even if the wrapped forward pass raises, so the
            # caller's buffer is never left holding random data.
            target.fill_(0)
4937

4938
4939
4940
4941
4942
4943
    def _get_mm_dummy_batch(
        self,
        modality: str,
        max_items_per_batch: int,
    ) -> BatchedTensorInputs:
        """Dummy data for profiling and precompiling multimodal models.

        One dummy item of ``modality`` is generated (via the multimodal budget's
        cache), replicated ``max_items_per_batch`` times, and the first kwargs
        group produced by ``group_mm_kwargs_by_modality`` is returned.
        """
        budget = self.mm_budget
        assert budget is not None

        # Request a single item and replicate it, instead of generating
        # `max_items_per_batch` identical items (redundant computation).
        generated = self.mm_registry.get_dummy_mm_inputs(
            self.model_config,
            mm_counts={modality: 1},
            cache=budget.cache,
        )
        single_item = generated["mm_kwargs"][modality][0]

        # We use the cache so that the item is saved to the cache,
        # but not read from the cache
        assert single_item is not None, "Item should not already be cached"

        groups = group_mm_kwargs_by_modality(
            [single_item] * max_items_per_batch,
            device=self.device,
            pin_memory=self.pin_memory,
        )
        _, _, first_group_kwargs = next(iter(groups))
        return first_group_kwargs
4968

4969
4970
4971
4972
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
4973
        cudagraph_runtime_mode: CUDAGraphMode | None = None,
4974
4975
        force_attention: bool = False,
        uniform_decode: bool = False,
4976
        allow_microbatching: bool = True,
4977
4978
        skip_eplb: bool = False,
        is_profile: bool = False,
4979
        create_mixed_batch: bool = False,
4980
        remove_lora: bool = True,
4981
        activate_lora: bool = False,
Rémi Delacourt's avatar
Rémi Delacourt committed
4982
        is_graph_capturing: bool = False,
4983
    ) -> tuple[torch.Tensor, torch.Tensor]:
4984
4985
4986
4987
4988
4989
4990
        """
        Run a dummy forward pass to warm up/profile run or capture the
        CUDA graph for the model.

        Args:
            num_tokens: Number of tokens to run the dummy forward pass.
            cudagraph_runtime_mode: used to control the behavior.
4991
                - if not set will determine the cudagraph mode based on using
4992
                    the self.cudagraph_dispatcher.
4993
4994
4995
4996
                - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
                    needed.
4997
            force_attention: If True, always create attention metadata. Used to
4998
4999
5000
5001
                warm up attention backend when mode is NONE.
            uniform_decode: If True, the batch is a uniform decode batch.
            skip_eplb: If True, skip EPLB state update.
            is_profile: If True, this is a profile run.
5002
5003
            create_mixed_batch: If True, create a mixed batch with both decode
                (1 token) and prefill (multiple tokens) requests.
5004
            remove_lora: If False, dummy LoRAs are not destroyed after the run
5005
            activate_lora: If False, dummy_run is performed without LoRAs.
5006
        """
5007
5008
        mm_config = self.vllm_config.model_config.multimodal_config
        if mm_config and mm_config.mm_encoder_only:
5009
5010
5011
5012
            # The current dummy run only covers LM execution, so we can skip it.
            # mm encoder dummy run may need to add in the future.
            return torch.tensor([]), torch.tensor([])

5013
5014
5015
5016
        assert (
            cudagraph_runtime_mode is None
            or cudagraph_runtime_mode.valid_runtime_modes()
        )
5017

5018
        # If cudagraph_mode.decode_mode() == FULL and
5019
        # cudagraph_mode.separate_routine(). This means that we are using
5020
5021
5022
5023
5024
5025
5026
5027
5028
5029
5030
        # different graphs and/or modes for mixed prefill-decode batches vs.
        # uniform decode batches. A uniform decode batch means that all
        # requests have identical query length, except a potential virtual
        # request (shorter) in the batch account for padding.
        # Uniform decode batch could either be common pure decode, where
        # max_query_len == 1, or speculative decode, where
        # max_query_len == 1 + num_spec_decode_tokens.

        # When setting max_query_len = 1, we switch to and capture the optimized
        # routine of FA2 for pure decode, i.e., Flashdecode + an optimization
        # for GQA/MQA.
5031
        max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens
5032

5033
5034
5035
5036
5037
        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
5038
5039
5040
5041
        if create_mixed_batch:
            assert not uniform_decode
            # Create mixed batch:
            # first half decode tokens, second half one prefill
5042
            num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2)
5043
5044
5045
5046
            num_prefill_tokens = num_tokens - num_decode_tokens
            num_reqs = num_decode_tokens + 1

            # Create decode requests (1 token each) followed by prefill request
5047
            num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
5048
5049
5050
            # Note: Overriding max_query_len to be the prefill tokens
            max_query_len = num_prefill_tokens
        elif uniform_decode:
5051
            assert not create_mixed_batch
5052
            num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
5053
5054
            num_scheduled_tokens_list = [max_query_len] * num_reqs
            if num_tokens % max_query_len != 0:
5055
                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
5056
5057
5058
5059
5060
5061
        else:
            num_reqs = min(num_tokens, max_num_reqs)
            min_tokens_per_req = num_tokens // num_reqs
            num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
            num_scheduled_tokens_list[-1] += num_tokens % num_reqs

5062
5063
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
5064
        num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
5065
5066
        num_tokens_unpadded = int(num_scheduled_tokens.sum())

5067
        num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
5068

5069
        _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
5070
5071
5072
5073
5074
5075
5076
5077
5078
5079
5080
5081
5082
5083
5084
5085
5086
5087
            self._determine_batch_execution_and_padding(
                num_tokens=num_tokens_unpadded,
                num_reqs=num_reqs,
                num_scheduled_tokens_np=num_scheduled_tokens,
                max_num_scheduled_tokens=max_query_len,
                use_cascade_attn=False,
                allow_microbatching=allow_microbatching,
                force_eager=is_profile
                or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
                # `force_uniform_decode` is used for cudagraph capture; because for
                # capturing mixed prefill-decode batches, we sometimes use
                # num_tokens == num_reqs which looks like a uniform decode batch to the
                # dispatcher; but we actually want to capture a piecewise cudagraph
                force_uniform_decode=uniform_decode,
                # `force_has_lora` is used for cudagraph capture; because LoRA is
                # activated later in the context manager, but we need to know the
                # LoRA state when determining the batch descriptor for capture
                force_has_lora=activate_lora,
5088
            )
5089
        )
5090
5091
5092

        if cudagraph_runtime_mode is None:
            cudagraph_runtime_mode = _cudagraph_mode
5093
        else:
5094
5095
5096
5097
            assert cudagraph_runtime_mode == _cudagraph_mode, (
                f"Cudagraph runtime mode mismatch in dummy_run. "
                f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}."
            )
5098

5099
5100
5101
5102
        num_tokens_padded = batch_desc.num_tokens
        num_reqs_padded = (
            batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
        )
5103
        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
5104
5105
5106
5107
5108
5109
5110
5111
5112
5113
            should_ubatch,
            num_scheduled_tokens,
            num_tokens_padded,
            num_reqs_padded,
            self.vllm_config.parallel_config.num_ubatches,
        )
        logger.debug(
            "ubatch_slices: %s, ubatch_slices_padded: %s",
            ubatch_slices,
            ubatch_slices_padded,
5114
        )
5115

5116
        attn_metadata: PerLayerAttnMetadata | None = None
5117

5118
5119
5120
5121
5122
5123
5124
        slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
            num_tokens_padded=num_tokens,
            num_reqs_padded=num_reqs_padded,
            num_tokens_unpadded=num_tokens_unpadded,
            ubatch_slices=ubatch_slices_padded,
        )

5125
5126
        # If force_attention is True, we always capture attention. Otherwise,
        # it only happens for cudagraph_runtime_mode=FULL.
5127
        if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
5128
5129
5130
5131
5132
5133
            if create_mixed_batch:
                # In the mixed batch mode (used for FI warmup), we use
                # shorter sequence lengths to run faster.
                # TODO(luka) better system for describing dummy batches
                seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
            else:
5134
5135
5136
5137
5138
                if not envs.VLLM_USE_PIECEWISE:
                    seq_lens = max_query_len
                else:
                    # Make sure max_model_len is used at the graph capture time.
                    seq_lens = self.max_model_len
5139
            self.seq_lens.np[:num_reqs] = seq_lens
5140
5141
            self.seq_lens.np[num_reqs:] = 0
            self.seq_lens.copy_to_gpu()
5142

5143
5144
            cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
            self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
5145
5146
            self.query_start_loc.copy_to_gpu()

5147
            pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
5148
            attn_metadata, _, _, _ = self._build_attention_metadata(
5149
5150
5151
                num_tokens=num_tokens_unpadded,
                num_reqs=num_reqs_padded,
                max_query_len=max_query_len,
5152
                ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
5153
                for_cudagraph_capture=is_graph_capturing,
5154
                slot_mappings=slot_mappings_by_group,
5155
            )
5156

5157
        with self.maybe_dummy_run_with_lora(
5158
5159
5160
5161
5162
            self.lora_config,
            num_scheduled_tokens,
            num_sampled_tokens,
            activate_lora,
            remove_lora,
5163
        ):
5164
            # Make sure padding doesn't exceed max_num_tokens
5165
            assert num_tokens_padded <= self.max_num_tokens
5166
            model_kwargs = self._init_model_kwargs()
5167
            if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
Patrick von Platen's avatar
Patrick von Platen committed
5168
5169
                input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded)

5170
                model_kwargs = {
5171
                    **model_kwargs,
5172
5173
                    **self._dummy_mm_kwargs(num_reqs),
                }
5174
5175
            elif self.enable_prompt_embeds:
                input_ids = None
5176
                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
5177
                model_kwargs = self._init_model_kwargs()
5178
            else:
5179
                input_ids = self.input_ids.gpu[:num_tokens_padded]
5180
                inputs_embeds = None
5181

guanyu1's avatar
guanyu1 committed
5182
            positions = self._get_positions(num_tokens_padded)
5183
5184
5185
5186
5187
5188
5189
5190
5191

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
5192
5193
5194
                            device=self.device,
                        )
                    )
5195
5196

                intermediate_tensors = self.sync_and_slice_intermediate_tensors(
5197
                    num_tokens_padded, None, False
5198
                )
5199

5200
            if ubatch_slices_padded is not None:
5201
5202
5203
                # Adjust values to reflect a single ubatch.
                # TODO(sage,lucas): this is cruft that should be addressed in
                #  the padding refactor.
5204
                num_tokens_padded = ubatch_slices_padded[0].num_tokens
5205
                if num_tokens_across_dp is not None:
5206
                    num_tokens_across_dp[:] = num_tokens_padded
5207

5208
            with (
5209
                self.maybe_randomize_inputs(input_ids, inputs_embeds),
5210
                set_forward_context(
5211
5212
                    attn_metadata,
                    self.vllm_config,
5213
                    num_tokens=num_tokens_padded,
5214
5215
                    num_tokens_across_dp=num_tokens_across_dp,
                    cudagraph_runtime_mode=cudagraph_runtime_mode,
5216
                    batch_descriptor=batch_desc,
5217
                    ubatch_slices=ubatch_slices_padded,
5218
                    slot_mapping=slot_mappings,
5219
5220
                    enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
                    enable_lightly_cplb=self.enable_lightly_cplb
5221
5222
                ),
            ):
5223
                outputs = self.model(
5224
5225
5226
5227
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
5228
                    **model_kwargs,
5229
                )
5230

5231
5232
5233
5234
            if self.use_aux_hidden_state_outputs:
                hidden_states, _ = outputs
            else:
                hidden_states = outputs
5235

5236
5237
5238
5239
            if self.speculative_config and (
                self.speculative_config.use_eagle()
                or self.speculative_config.uses_draft_model()
            ):
王敏's avatar
王敏 committed
5240
5241
5242
5243
5244
5245
5246
5247
5248
5249
5250
5251
5252
5253
5254
5255
5256
5257
5258
5259
5260
5261
5262
5263
5264
5265
5266
5267
5268
                #assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
                if hasattr(self, "drafter") and isinstance(self.drafter, EagleProposer | DraftModelProposer):
                    assert self.speculative_config is not None
                    # Eagle currently only supports PIECEWISE cudagraphs.
                    # Therefore only use cudagraphs if the main model uses PIECEWISE
                    # NOTE(lucas): this is a hack, need to clean up.
                    use_cudagraphs = (
                        (
                            is_graph_capturing
                            and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
                        )
                        or (
                            not is_graph_capturing
                            and cudagraph_runtime_mode != CUDAGraphMode.NONE
                        )
                    ) and not self.speculative_config.enforce_eager

                    # Note(gnovack) - We need to disable cudagraphs for one of the two
                    # lora cases when cudagraph_specialize_lora is enabled. This is a
                    # short term mitigation for issue mentioned in
                    # https://github.com/vllm-project/vllm/issues/28334
                    if self.compilation_config.cudagraph_specialize_lora and activate_lora:
                        use_cudagraphs = False

                    self.drafter.dummy_run(
                        num_tokens,
                        use_cudagraphs=use_cudagraphs,
                        is_graph_capturing=is_graph_capturing,
                        slot_mappings=slot_mappings,
5269
                    )
5270

5271
5272
5273
5274
5275
5276
5277
5278
5279
5280
5281
        # We register layerwise NVTX hooks here after the first dynamo tracing is
        # done to avoid nvtx operations in hook functions being traced by
        # torch dynamo and causing graph breaks.
        # Note that for DYNAMO_ONCE and VLLM_COMPILE mode,
        # compiled model's dynamo tracing is only done once and the compiled model's
        # __call__ function is replaced by calling the compiled function.
        # So it's safe to register hooks here. Hooks will be registered to
        # both compiled and uncompiled models but they will never
        # be called on the compiled model execution path.
        self._register_layerwise_nvtx_hooks()

5282
5283
5284
5285
5286
5287
5288
5289
5290
5291
        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
        # requests to process.
        # However, in DP settings, there may be cases when some DP ranks do
        # not have any requests to process, so they're executing dummy batches.
        # In such cases, we still have to trigger EPLB to make sure
        # ranks execute the rearrangement in synchronization.
        if not skip_eplb:
            self.eplb_step(is_dummy=True, is_profile=is_profile)

5292
        logit_indices = np.cumsum(num_scheduled_tokens) - 1
5293
5294
5295
5296
5297
5298
5299
5300
5301
        # logit_indices_device = torch.from_numpy(logit_indices).to(
        #     self.device, non_blocking=True
        # )
        logit_indices = logit_indices.tolist()
        logit_indices_device = async_tensor_h2d(
                    logit_indices,
                    dtype=torch.int32,
                    target_device=self.device,
                    pin_memory=True)
5302
        return hidden_states, hidden_states[logit_indices_device]
5303
5304
5305
5306
5307
5308

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the sampler (and the rejection sampler) on dummy requests.

        The number of dummy requests is the number of rows of the logits
        computed from ``hidden_states``. When speculative decoding is
        configured, the rejection sampler is run once as well.

        Args:
            hidden_states: Hidden states from the dummy forward pass. Only
                their shape/dtype/device are used; the values are replaced by
                random data before computing logits.

        Returns:
            The sampler output, or an empty tensor for encoder-only
            multimodal models, which have no sampling step.

        Raises:
            RuntimeError: With an actionable message when the sampler warmup
                runs out of memory; any other ``RuntimeError`` is re-raised
                unchanged.
        """
        mm_config = self.vllm_config.model_config.multimodal_config
        if mm_config and mm_config.mm_encoder_only:
            # MM encoder-only models have no sampler to run.
            return torch.tensor([])

        # The dummy hidden states may contain special values such as `inf`
        # or `nan`. To avoid breaking the sampler, use a random tensor instead.
        hidden_states = torch.rand_like(hidden_states)

        logits = self.model.compute_logits(hidden_states)
        num_reqs = logits.size(0)

        def dummy_tensors(v: float) -> torch.Tensor:
            # One copy of `v` per dummy request, on the runner's device.
            return torch.full((num_reqs,), v, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=dummy_tensors(0.5),
            all_greedy=False,
            all_random=False,
            top_p=dummy_tensors(0.9),
            top_k=dummy_tensors(logits.size(1) - 1),
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=dummy_tensors(0.1),
            presence_penalties=dummy_tensors(0.1),
            repetition_penalties=dummy_tensors(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            spec_token_ids=[[] for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=LogitsProcessors(),
        )
        try:
            sampler_output = self.sampler(
                logits=logits, sampling_metadata=dummy_metadata
            )
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            raise RuntimeError(
                "CUDA out of memory occurred when warming up sampler with "
                f"{num_reqs} dummy requests. Please try lowering "
                "`max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine."
            ) from e

        if self.speculative_config:
            # A single dummy draft token per request.
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device
            )
            num_tokens = sum(len(ids) for ids in draft_token_ids)

            if envs.VLLM_REJECT_SAMPLE_OPT:
                # NOTE(review): draft_probs is sized by num_speculative_tokens
                # while draft_token_ids above holds one token per request;
                # confirm this is the layout the optimized rejection sampler
                # expects.
                draft_probs = torch.randn(
                    num_reqs,
                    self.speculative_config.num_speculative_tokens,
                    logits.shape[-1],
                    device=self.device,
                    dtype=logits.dtype,
                )
                dummy_metadata.all_greedy = True
            else:
                draft_probs = None

            # `num_tokens + num_reqs` rows: presumably one per draft token plus
            # one bonus row per request.
            logits = torch.randn(
                num_tokens + num_reqs,
                logits.shape[-1],
                device=self.device,
                dtype=logits.dtype,
            )
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                draft_probs,
                logits,
                dummy_metadata,
            )
        return sampler_output
5390

5391
    def _dummy_pooler_run_task(
        self,
        hidden_states: torch.Tensor,
        task: PoolingTask,
    ) -> PoolerOutput:
        """Run the model's pooler once for ``task`` on a dummy batch.

        The ``hidden_states.shape[0]`` tokens are split across
        ``min(num_tokens, max_num_seqs)`` dummy requests of equal length.
        The last request absorbs the remainder.

        Args:
            hidden_states: Hidden states from the dummy forward pass. The
                size of dim 0 is taken as the total number of tokens.
            task: Pooling task to warm up.

        Returns:
            The pooler output for the dummy batch.

        Raises:
            RuntimeError: With an actionable message when pooling runs out of
                memory; any other ``RuntimeError`` is re-raised unchanged.
        """
        num_tokens = hidden_states.shape[0]
        num_reqs = min(num_tokens, self.scheduler_config.max_num_seqs)
        min_tokens_per_req = num_tokens // num_reqs

        # Spread tokens evenly; the last request takes the remainder.
        num_scheduled_tokens_np = np.full(num_reqs, min_tokens_per_req)
        num_scheduled_tokens_np[-1] += num_tokens % num_reqs
        assert np.sum(num_scheduled_tokens_np) == num_tokens
        assert len(num_scheduled_tokens_np) == num_reqs

        dummy_prompt_lens = torch.from_numpy(num_scheduled_tokens_np)
        dummy_token_ids = torch.zeros(
            (num_reqs, min_tokens_per_req), dtype=torch.int32, device=self.device
        )

        model = cast(VllmModelForPooling, self.get_model())
        dummy_pooling_params = PoolingParams(task=task)
        dummy_pooling_params.verify(task=task, model_config=self.model_config)
        # Apply the pooler's task-specific updates to the dummy params.
        to_update = model.pooler.get_pooling_updates(task)
        to_update.apply(dummy_pooling_params)

        dummy_metadata = PoolingMetadata(
            prompt_lens=dummy_prompt_lens,
            prompt_token_ids=dummy_token_ids,
            pooling_params=[dummy_pooling_params] * num_reqs,
            pooling_states=[PoolingStates() for _ in range(num_reqs)],
        )
        dummy_metadata.build_pooling_cursor(
            num_scheduled_tokens_np,
            seq_lens_cpu=dummy_prompt_lens,
            device=hidden_states.device,
        )

        try:
            return model.pooler(
                hidden_states=hidden_states, pooling_metadata=dummy_metadata
            )
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            raise RuntimeError(
                "CUDA out of memory occurred when warming up pooler "
                f"({task=}) with {num_reqs} dummy requests. Please try "
                "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine."
            ) from e
5445
5446
5447
5448
5449
5450

    @torch.inference_mode()
    def _dummy_pooler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> PoolerOutput:
        """Warm up every supported pooling task, then rerun the largest one.

        Each supported task is run once on the full dummy batch so that none
        of them can OOM later. The task whose output has the most bytes
        (summed ``nbytes`` of non-``None`` items) is run again, and that
        output is returned for subsequent steps.

        Raises:
            RuntimeError: If the model supports no pooling task.
        """
        mm_config = self.vllm_config.model_config.multimodal_config
        if mm_config and mm_config.mm_encoder_only:
            # MM encoder-only models have no pooler to run.
            return torch.tensor([])

        pooling_tasks = self.get_supported_pooling_tasks()
        if not pooling_tasks:
            raise RuntimeError(
                f"Model {self.model_config.model} does not support "
                "any pooling tasks. See "
                "https://docs.vllm.ai/en/latest/models/pooling_models.html "
                "to learn more."
            )

        # Total output bytes produced by each task on the full dummy batch.
        bytes_per_task: dict[PoolingTask, float] = {}
        for pooling_task in pooling_tasks:
            task_output = self._dummy_pooler_run_task(hidden_states, pooling_task)
            bytes_per_task[pooling_task] = sum(
                item.nbytes for item in task_output if item is not None
            )
            del task_output  # Release before the next task runs.

        largest_task = max(bytes_per_task, key=bytes_per_task.__getitem__)
        return self._dummy_pooler_run_task(hidden_states, largest_task)
5476

5477
    def profile_run(self) -> None:
        """Run a dummy workload used for memory profiling.

        Steps:
        1. For multimodal models (unless ``skip_mm_profiling`` is set), run
           the multimodal encoder on a dummy batch. Its outputs are placed
           in the encoder cache under ``tmp_<i>`` keys.
        2. Run a dummy forward pass over ``max_num_tokens`` tokens.
        3. On the last pipeline rank, run a dummy pooler or sampler pass.

        Afterwards the device is synchronized, the encoder cache is cleared
        and garbage is collected.
        """
        if self.supports_mm_inputs:
            mm_config = self.model_config.multimodal_config
            if mm_config is not None and mm_config.skip_mm_profiling:
                logger.info(
                    "Skipping memory profiling for multimodal encoder and "
                    "encoder cache."
                )
            else:
                mm_budget = self.mm_budget
                assert mm_budget is not None

                encoder_budget = mm_budget.get_encoder_budget()
                if encoder_budget > 0:
                    # NOTE: Currently the model is profiled with a single
                    # non-text modality with the max possible input tokens,
                    # even when it supports multiple.
                    modality = mm_budget.get_modality_with_max_tokens()
                    num_items = mm_budget.max_items_per_batch_by_modality[modality]

                    logger.info(
                        "Encoder cache will be initialized with a budget of "
                        "%s tokens, and profiled with %s %s items of the "
                        "maximum feature size.",
                        encoder_budget,
                        num_items,
                        modality,
                    )

                    # Build a dummy multimodal batch and run the encoder on it.
                    mm_batch = self._get_mm_dummy_batch(modality, num_items)
                    encoder_outputs = self.model.embed_multimodal(**mm_batch)
                    sanity_check_mm_encoder_outputs(
                        encoder_outputs,
                        expected_num_items=num_items,
                    )

                    # Park the outputs in the encoder cache; they are cleared
                    # at the end of this method.
                    for idx, encoder_output in enumerate(encoder_outputs):
                        self.encoder_cache[f"tmp_{idx}"] = encoder_output

        # `is_profile=True` also pre-allocates communication buffers.
        hidden_states, last_hidden_states = self._dummy_run(
            self.max_num_tokens, is_profile=True
        )
        output = None
        if get_pp_group().is_last_rank:
            if self.is_pooling_model:
                output = self._dummy_pooler_run(hidden_states)
            else:
                output = self._dummy_sampler_run(last_hidden_states)
        self._sync_device()
        del hidden_states, output
        self.encoder_cache.clear()
        gc.collect()
5541

5542
    def capture_model(self) -> int:
        """Capture CUDA graphs for every descriptor the dispatcher provides.

        Returns:
            The drop in free GPU memory (bytes) across the capture, as
            reported by ``torch.cuda.mem_get_info``. Returns 0 when capture
            is skipped because ``cudagraph_mode`` is ``NONE``.
        """
        if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        compilation_counter.num_gpu_runner_capture_triggers += 1
        start_time = time.perf_counter()

        @contextmanager
        def freeze_gc():
            # Collect once up front. Then, unless VLLM_ENABLE_CUDAGRAPH_GC is
            # set, freeze every surviving object so it is excluded from
            # collections during capture. Unfreeze and collect on exit.
            gc.collect()
            freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
            if freeze:
                gc.freeze()
            try:
                yield
            finally:
                if freeze:
                    gc.unfreeze()
                    gc.collect()

        # Capture CUDA graphs for specific shapes. The large shapes are meant
        # to be captured first, so that smaller shapes can reuse the memory
        # pool allocated for them.
        set_cudagraph_capturing_enabled(True)
        with freeze_gc(), graph_capture(device=self.device):
            free_before, _ = torch.cuda.mem_get_info()
            capture_descs = self.cudagraph_dispatcher.get_capture_descs()
            for runtime_mode, batch_descs in capture_descs:
                self._capture_cudagraphs(
                    batch_descriptors=batch_descs,
                    cudagraph_runtime_mode=runtime_mode,
                )
            torch.cuda.synchronize()
            free_after, _ = torch.cuda.mem_get_info()

        # Disable cudagraph capturing globally, so any unexpected capture
        # from here on is detected and raises an error. This is deliberately
        # not part of the graph_capture context manager, because lazy
        # capturing may be allowed after this point in the future.
        set_cudagraph_capturing_enabled(False)

        # Lock the workspace so it cannot be resized during execution. Max
        # workspace sizes should have been captured during warmup/profiling.
        lock_workspace()

        elapsed_time = time.perf_counter() - start_time
        cuda_graph_size = free_before - free_after
        # This usually takes 5~20 seconds.
        logger.info_once(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
            scope="local",
        )
        return cuda_graph_size
5611

5612
5613
    def _capture_cudagraphs(
        self,
        batch_descriptors: list[BatchDescriptor],
        cudagraph_runtime_mode: CUDAGraphMode,
    ):
        """Warm up and then capture one CUDA graph per batch descriptor.

        The uniform-decode flag for the whole list is taken from the first
        descriptor. For each descriptor, ``_dummy_run`` is called
        ``cudagraph_num_of_warmups`` times with ``CUDAGraphMode.NONE`` and
        then once in capture mode. ``maybe_remove_all_loras`` is called at
        the end.

        Args:
            batch_descriptors: Batch shapes to capture; may be empty.
            cudagraph_runtime_mode: Mode to capture in; must not be ``NONE``.
        """
        assert (
            cudagraph_runtime_mode != CUDAGraphMode.NONE
            and cudagraph_runtime_mode.valid_runtime_modes()
        ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"

        if not batch_descriptors:
            return

        uniform_decode = batch_descriptors[0].uniform
        # EPLB is skipped because dummy metrics must not be recorded.
        dummy_run = functools.partial(
            self._dummy_run,
            uniform_decode=uniform_decode,
            skip_eplb=True,
            remove_lora=False,
            force_attention=cudagraph_runtime_mode == CUDAGraphMode.FULL,
        )

        progress = batch_descriptors
        # Only the global first rank shows a progress bar during capture.
        if is_global_first_rank():
            progress = tqdm(
                batch_descriptors,
                disable=not self.load_config.use_tqdm_on_load,
                desc="Capturing CUDA graphs ({}, {})".format(
                    "decode" if uniform_decode else "mixed prefill-decode",
                    cudagraph_runtime_mode.name,
                ),
            )

        for batch_desc in progress:
            num_tokens = batch_desc.num_tokens

            # Ubatched graphs are only captured for FULL cudagraphs of
            # uniform-decode batches whose token count passes the ubatch
            # threshold; every other case gets a non-ubatched graph.
            allow_microbatching = (
                self.parallel_config.use_ubatching
                and cudagraph_runtime_mode == CUDAGraphMode.FULL
                and uniform_decode
                and check_ubatch_thresholds(
                    config=self.vllm_config.parallel_config,
                    num_tokens=num_tokens,
                    uniform_decode=uniform_decode,
                )
            )
            run_desc = functools.partial(
                dummy_run,
                num_tokens,
                allow_microbatching=allow_microbatching,
                activate_lora=batch_desc.has_lora,
            )

            # Warm up with CUDAGraphMode.NONE. Note that warming up with
            # NONE is orthogonal to whether attention is warmed up; this
            # differs from capture, where FULL implies attention is captured
            # and PIECEWISE implies it is not.
            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                run_desc(cudagraph_runtime_mode=CUDAGraphMode.NONE)

            # Capture run.
            run_desc(
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                is_graph_capturing=True,
            )
        self.maybe_remove_all_loras(self.lora_config)
5689

5690
5691
5692
5693
    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """Group each KV cache group's layers by attention backend and spec.

        For every KV cache group, layers that share the same attention
        backend class name and per-layer KV cache spec are collected into one
        ``AttentionGroup``, appended to ``self.attn_groups``. Before the
        groups are built, the cudagraph mode is resolved against the
        backends found, and CP compatibility is checked. Metadata builders
        are not created here.
        """
        assert len(self.attn_groups) == 0, "Attention backends are already initialized"

        class AttentionGroupKey(NamedTuple):
            attn_backend: type[AttentionBackend]
            kv_cache_spec: KVCacheSpec

        def get_attn_backends_for_group(
            kv_cache_group_spec: KVCacheGroupSpec,
        ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
            layers = get_layers_from_vllm_config(
                self.vllm_config,
                cast(type[Any], AttentionLayerBase),
                kv_cache_group_spec.layer_names,
            )
            group_spec = kv_cache_group_spec.kv_cache_spec
            # (full class name, layer spec) -> group key (last layer wins).
            group_keys = {}
            # (full class name, layer spec) -> layer names, in order.
            layers_by_key = defaultdict(list)

            # Dedupe on the backend's full class name rather than the class
            # object. Dynamically created backend subclasses (e.g.
            # ChunkedLocalAttention) can otherwise be distinct objects per
            # layer, unless they are cached correctly.
            for layer_name in kv_cache_group_spec.layer_names:
                backend = layers[layer_name].get_attn_backend()
                if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
                    backend = create_fast_prefill_custom_backend(
                        "FastPrefill",
                        backend,  # type: ignore[arg-type]
                    )

                cls_name = backend.full_cls_name()
                layer_spec = group_spec
                if isinstance(layer_spec, UniformTypeKVCacheSpecs):
                    layer_spec = layer_spec.kv_cache_specs[layer_name]

                dedupe_key = (cls_name, layer_spec)
                group_keys[dedupe_key] = AttentionGroupKey(backend, layer_spec)
                layers_by_key[dedupe_key].append(layer_name)

            backend_map = {
                group_keys[key]: names for key, names in layers_by_key.items()
            }
            backend_set = {key.attn_backend for key in group_keys.values()}
            return backend_map, backend_set

        def create_attn_groups(
            attn_backends_map: dict[AttentionGroupKey, list[str]],
            kv_cache_group_id: int,
        ) -> list[AttentionGroup]:
            return [
                AttentionGroup(backend, layer_names, spec, kv_cache_group_id)
                for (backend, spec), layer_names in attn_backends_map.items()
            ]

        backend_maps = []
        backend_sets = []
        for group_spec in kv_cache_config.kv_cache_groups:
            backend_map, backend_set = get_attn_backends_for_group(group_spec)
            backend_maps.append(backend_map)
            backend_sets.append(backend_set)

        # Resolve cudagraph_mode before the metadata builders are initialized.
        self._check_and_update_cudagraph_mode(
            backend_sets, kv_cache_config.kv_cache_groups
        )

        # Check if the attention backends support PCP & DCP and related features.
        check_attention_cp_compatibility(self.vllm_config)

        for group_id, backend_map in enumerate(backend_maps):
            self.attn_groups.append(create_attn_groups(backend_map, group_id))
5770

5771
5772
5773
5774
5775
5776
5777
5778
5779
5780
5781
5782
5783
5784
5785
    def initialize_metadata_builders(
        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
    ) -> None:
        """Create metadata builders for every attention group.

        Args:
            kv_cache_config: KV cache configuration. ``self.attn_groups`` is
                indexed once per entry in ``kv_cache_groups``.
            kernel_block_sizes: Kernel block size per KV cache group. Groups
                beyond the end of this list receive ``None``.
        """
        # One builder normally; one per micro-batch when ubatching is on.
        num_builders = (
            self.parallel_config.num_ubatches
            if self.parallel_config.use_ubatching
            else 1
        )
        num_block_sizes = len(kernel_block_sizes)

        for group_id in range(len(kv_cache_config.kv_cache_groups)):
            block_size = (
                kernel_block_sizes[group_id] if group_id < num_block_sizes else None
            )
            for attn_group in self.attn_groups[group_id]:
                attn_group.create_metadata_builders(
                    self.vllm_config,
                    self.device,
                    block_size,
                    num_metadata_builders=num_builders,
                )

        # Calculate the reorder batch threshold (if needed).
        # Note (tdoublep): this must run *after* the builders are constructed,
        # because some of them change the threshold at init time.
        self.calculate_reorder_batch_threshold()

5794
    def _check_and_update_cudagraph_mode(
5795
5796
5797
        self,
        attention_backends: list[set[type[AttentionBackend]]],
        kv_cache_groups: list[KVCacheGroupSpec],
5798
    ) -> None:
5799
        """
5800
        Resolve the cudagraph_mode when there are multiple attention
5801
        groups with potential conflicting CUDA graph support.
5802
5803
5804
        Then initialize the cudagraph_dispatcher based on the resolved
        cudagraph_mode.
        """
5805
        min_cg_support = AttentionCGSupport.ALWAYS
5806
        min_cg_backend_name = None
5807

5808
5809
5810
5811
5812
        for attn_backend_set, kv_cache_group in zip(
            attention_backends, kv_cache_groups
        ):
            for attn_backend in attn_backend_set:
                builder_cls = attn_backend.get_builder_cls()
5813

5814
5815
5816
5817
5818
5819
                cg_support = builder_cls.get_cudagraph_support(
                    self.vllm_config, kv_cache_group.kv_cache_spec
                )
                if cg_support.value < min_cg_support.value:
                    min_cg_support = cg_support
                    min_cg_backend_name = attn_backend.__name__
5820
5821
        # Flexible resolve the cudagraph mode
        cudagraph_mode = self.compilation_config.cudagraph_mode
5822
        assert cudagraph_mode is not None
5823
        # check cudagraph for mixed batch is supported
5824
5825
5826
5827
5828
5829
        if (
            cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
            and min_cg_support != AttentionCGSupport.ALWAYS
        ):
            msg = (
                f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
5830
                f"with {min_cg_backend_name} backend (support: "
5831
5832
                f"{min_cg_support})"
            )
5833
5834
            if min_cg_support == AttentionCGSupport.NEVER:
                # if not supported any full cudagraphs, just raise it.
5835
5836
                msg += (
                    "; please try cudagraph_mode=PIECEWISE, and "
5837
                    "make sure compilation mode is VLLM_COMPILE"
5838
                )
5839
5840
5841
5842
5843
                raise ValueError(msg)

            # attempt to resolve the full cudagraph related mode
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
5844
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
5845
                    CUDAGraphMode.FULL_AND_PIECEWISE
5846
                )
5847
5848
            else:
                msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
5849
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
5850
                    CUDAGraphMode.FULL_DECODE_ONLY
5851
                )
5852
5853
            logger.warning(msg)

5854
        # check that if we are doing decode full-cudagraphs it is supported
5855
5856
5857
5858
        if not envs.VLLM_USE_PIECEWISE:
            if (
                cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
                and min_cg_support == AttentionCGSupport.NEVER
5859
            ):
5860
5861
5862
5863
                msg = (
                    f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
                    f"with {min_cg_backend_name} backend (support: "
                    f"{min_cg_support})"
5864
                )
5865
5866
5867
5868
5869
5870
5871
5872
5873
5874
5875
5876
5877
5878
5879
5880
5881
5882
5883
5884
                if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
                    self.compilation_config.splitting_ops_contain_attention()
                    or self.compilation_config.use_inductor_graph_partition
                ):
                    msg += (
                        "; setting cudagraph_mode=PIECEWISE because "
                        "attention is compiled piecewise"
                    )
                    cudagraph_mode = self.compilation_config.cudagraph_mode = (
                        CUDAGraphMode.PIECEWISE
                    )
                else:
                    msg += (
                        "; setting cudagraph_mode=NONE because "
                        "attention is not compiled piecewise"
                    )
                    cudagraph_mode = self.compilation_config.cudagraph_mode = (
                        CUDAGraphMode.NONE
                    )
                logger.warning(msg)
5885

5886
5887
        # check that if we are doing spec-decode + decode full-cudagraphs it is
        # supported
5888
5889
5890
5891
5892
5893
5894
5895
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and self.uniform_decode_query_len > 1
            and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value
        ):
            msg = (
                f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
                f" with spec-decode for attention backend "
5896
                f"{min_cg_backend_name} (support: {min_cg_support})"
5897
            )
5898
5899
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=PIECEWISE"
5900
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
5901
                    CUDAGraphMode.PIECEWISE
5902
                )
5903
5904
            else:
                msg += "; setting cudagraph_mode=NONE"
5905
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
5906
                    CUDAGraphMode.NONE
5907
                )
5908
5909
5910
5911
            logger.warning(msg)

        # double check that we can support full cudagraph if they are requested
        # even after automatic downgrades
5912
5913
5914
5915
5916
5917
        if (
            cudagraph_mode.has_full_cudagraphs()
            and min_cg_support == AttentionCGSupport.NEVER
        ):
            raise ValueError(
                f"CUDAGraphMode.{cudagraph_mode.name} is not "
5918
                f"supported with {min_cg_backend_name} backend ("
5919
5920
                f"support:{min_cg_support}) "
                "; please try cudagraph_mode=PIECEWISE, "
5921
                "and make sure compilation mode is VLLM_COMPILE"
5922
            )
5923

5924
5925
5926
5927
        # if we have dedicated decode cudagraphs, and spec-decode is enabled,
        # we need to adjust the cudagraph sizes to be a multiple of the uniform
        # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
        # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
5928
        # Will be removed in the near future when we have separate cudagraph capture
5929
5930
5931
5932
5933
5934
5935
5936
5937
        # sizes for decode and mixed prefill-decode.
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and cudagraph_mode.separate_routine()
            and self.uniform_decode_query_len > 1
        ):
            self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
                self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
            )
5938
5939
5940
5941
            capture_sizes = self.compilation_config.cudagraph_capture_sizes
            self.cudagraph_batch_sizes = (
                capture_sizes if capture_sizes is not None else []
            )
5942

5943
5944
        # Trigger cudagraph dispatching keys initialization after
        # resolved cudagraph mode.
5945
        self.compilation_config.cudagraph_mode = cudagraph_mode
5946
        self.cudagraph_dispatcher.initialize_cudagraph_keys(
5947
            cudagraph_mode, self.uniform_decode_query_len
5948
        )
5949

5950
        # Initialize eagle's cudagraph dispatcher if using eagle spec decode.
5951
5952
        if self.speculative_config and self.speculative_config.use_eagle() and hasattr(self, "drafter") \
            and get_pp_group().is_last_rank:
5953
5954
5955
            assert isinstance(self.drafter, EagleProposer)
            self.drafter.initialize_cudagraph_keys(cudagraph_mode)

5956
5957
    def calculate_reorder_batch_threshold(self) -> None:
        """
        Choose the minimum reorder batch threshold from all attention groups.

        Backends should be able to support a lower threshold than what they
        request; they may just pay a performance penalty because the backend
        then treats some decodes as prefills.

        Metadata builders that report ``None`` impose no constraint. If there
        are no attention groups (attention-free model), or no builder reports
        a threshold, ``self.reorder_batch_threshold`` is set to ``None`` so
        reordering stays disabled.
        """
        reported_thresholds: list[int] = [
            threshold
            for threshold in (
                group.get_metadata_builder().reorder_batch_threshold
                for group in self._attn_group_iterator()
            )
            if threshold is not None
        ]
        # min() over an empty list falls back to None: this covers both the
        # "no attention groups" and the "nobody reported a threshold" cases.
        self.reorder_batch_threshold = min(reported_thresholds, default=None)  # type: ignore[assignment]
5975

5976
5977
5978
    @staticmethod
    def select_common_block_size(
        kv_manager_block_size: int, attn_groups: list[AttentionGroup]
    ) -> int:
        """
        Pick a kernel block size accepted by every participating backend.

        The answer is ``kv_manager_block_size`` itself when all backends accept
        it. Otherwise it is the largest ``int``-format supported size that
        divides ``kv_manager_block_size`` and is accepted by every backend.

        Backends whose class sets ``exclude_from_block_size_selection`` to a
        truthy value (e.g. an indexer backend) do not take part in the choice.

        Args:
            kv_manager_block_size: Block size used by the KV cache manager.
            attn_groups: Attention groups whose backends must all agree.

        Returns:
            The selected kernel block size.

        Raises:
            ValueError: If no common block size exists, or a backend advertises
                a supported size that is neither an ``int`` nor a
                ``MultipleOf``.
        """

        def _size_matches(block_size: int, supported_size: object) -> bool:
            # An int must match exactly; MultipleOf(base) accepts any multiple.
            if isinstance(supported_size, int):
                return block_size == supported_size
            if isinstance(supported_size, MultipleOf):
                return block_size % supported_size.base == 0
            raise ValueError(f"Unknown supported size: {supported_size}")

        def block_size_is_supported(
            backends: list[type[AttentionBackend]], block_size: int
        ) -> bool:
            """Return True iff every backend in ``backends`` accepts ``block_size``."""
            for backend in backends:
                # Evaluate every advertised size (a list, not a short-circuiting
                # any() over a generator) so that an unrecognised format raises
                # even when an earlier entry already matched.
                matches = [
                    _size_matches(block_size, supported_size)
                    for supported_size in backend.get_supported_kernel_block_sizes()
                ]
                if not any(matches):
                    return False
            return True

        backends = [
            group.backend
            for group in attn_groups
            if not getattr(group.backend, "exclude_from_block_size_selection", False)
        ]

        # Case 1: the KV cache manager's block size already works for everyone.
        if block_size_is_supported(backends, kv_manager_block_size):
            return kv_manager_block_size

        # Case 2: the answer must then be an `int`-format size advertised by at
        # least one backend. If a candidate b were MultipleOf(x_i) for every
        # backend i and b divided kv_manager_block_size, kv_manager_block_size
        # would satisfy every MultipleOf(x_i) as well and Case 1 would have
        # returned. So try the int-format sizes from largest to smallest.
        candidate_sizes = {
            supported_size
            for backend in backends
            for supported_size in backend.get_supported_kernel_block_sizes()
            if isinstance(supported_size, int)
        }
        for candidate in sorted(candidate_sizes, reverse=True):
            divides_manager_block = kv_manager_block_size % candidate == 0
            if divides_manager_block and block_size_is_supported(backends, candidate):
                return candidate
        raise ValueError(f"No common block size for {kv_manager_block_size}. ")
6056

6057
6058
6059
    def may_reinitialize_input_batch(
        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
    ) -> None:
        """
        Re-initialize the input batch if the block sizes are different from
        `[self.cache_config.block_size]`. This usually happens when there
        are multiple KV cache groups.

        Encoder-only attention groups are skipped. They contribute no entry to
        the block sizes or to the per-request block-count limits handed to
        ``InputBatch``.

        Args:
            kv_cache_config: The KV cache configuration.
            kernel_block_sizes: The kernel block sizes for each KV cache group,
                one per non-encoder-only group, in group order.
        """
        max_model_len = max(self.max_model_len, self.max_encoder_len)
        block_sizes: list[int] = []
        max_num_blocks: list[int] = []
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
                continue
            # Take the block size from the group itself rather than indexing a
            # filtered list with an unfiltered enumerate() position. That keeps
            # the two aligned even when an encoder-only group precedes others.
            block_size = kv_cache_spec.block_size
            block_sizes.append(block_size)

            max_num_blocks_per_req = cdiv(
                max_model_len, block_size * get_total_cp_world_size()
            )
            if isinstance(kv_cache_spec, MambaSpec):
                # Mamba needs (per-request blocks if prefix caching, else 1)
                # plus its speculative blocks; keep whichever limit is larger.
                mamba_blocks_per_req = (
                    max_num_blocks_per_req
                    if self.cache_config.enable_prefix_caching
                    else 1
                ) + kv_cache_spec.num_speculative_blocks
                max_num_blocks_per_req = max(
                    max_num_blocks_per_req, mamba_blocks_per_req
                )
            max_num_blocks.append(max_num_blocks_per_req)

        if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
            self.cache_config.block_size
        ]:
            assert self.cache_config.cpu_offload_gb == 0, (
                "Cannot re-initialize the input batch when CPU weight "
                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
                "for more details."
            )
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=block_sizes,
                kernel_block_sizes=kernel_block_sizes,
                max_num_blocks_per_req=max_num_blocks,
                is_spec_decode=bool(self.vllm_config.speculative_config),
                logitsprocs=self.input_batch.logitsprocs,
                logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
                is_pooling_model=self.is_pooling_model,
            )

6117
    def _allocate_kv_cache_tensors(
        self, kv_cache_config: KVCacheConfig
    ) -> dict[str, torch.Tensor]:
        """
        Allocate the raw KV cache byte buffers.

        One zero-filled ``torch.int8`` tensor of ``size`` elements is created
        on ``self.device`` for each entry of ``kv_cache_config.kv_cache_tensors``.
        Every layer named in that entry's ``shared_by`` maps to the same tensor
        object. The buffers still have to be reshaped before models use them.

        Args:
            kv_cache_config: The KV cache config

        Returns:
            dict[str, torch.Tensor]: Layer name -> (possibly shared) raw buffer.

        Raises:
            AssertionError: If the layers that received a buffer differ from the
                layers listed in ``kv_cache_config.kv_cache_groups``, excluding
                ``self.runner_only_attn_layers``.
        """
        raw_by_layer: dict[str, torch.Tensor] = {}
        for tensor_spec in kv_cache_config.kv_cache_tensors:
            buffer = torch.zeros(tensor_spec.size, dtype=torch.int8, device=self.device)
            raw_by_layer.update(dict.fromkeys(tensor_spec.shared_by, buffer))

        expected_layers = {
            name
            for group in kv_cache_config.kv_cache_groups
            for name in group.layer_names
            if name not in self.runner_only_attn_layers
        }
        assert expected_layers == set(raw_by_layer), (
            "Some layers are not correctly initialized"
        )
        return raw_by_layer

6149
6150
6151
    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
        """
        Iterate over every attention group in order, flattening
        ``self.attn_groups`` (a list of groups per KV cache group).
        """
        return (group for groups in self.attn_groups for group in groups)

6152
    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
        """
        Lazily yield every attention group.

        Yields nothing at all when ``self.kv_cache_config`` has no KV cache
        groups. The check runs when iteration starts, not when this method is
        called.
        """
        if self.kv_cache_config.kv_cache_groups:
            for groups in self.attn_groups:
                for group in groups:
                    yield group
6157

6158
6159
6160
6161
6162
6163
6164
6165
6166
6167
6168
6169
6170
6171
6172
    def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
        """
        Compute one kernel block size per KV cache group.

        - Attention groups may split each KV-manager block into smaller kernel
          blocks. Their size comes from ``select_common_block_size`` over the
          backends in that group.
        - Mamba groups are not split and keep their own block size.
        - Encoder-only attention groups are skipped and get no entry.

        Args:
            kv_cache_config: The KV cache configuration.

        Returns:
            list[int]: Kernel block sizes in KV cache group order (encoder-only
            groups omitted).

        Raises:
            NotImplementedError: For any other KV cache spec type.
        """
        kernel_block_sizes: list[int] = []
        for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            spec = kv_cache_group.kv_cache_spec
            if isinstance(spec, UniformTypeKVCacheSpecs):
                # Every layer in a UniformTypeKVCacheSpecs has the same type,
                # so any member is representative for dispatch.
                spec = next(iter(spec.kv_cache_specs.values()))

            # The encoder-only check must come before the generic
            # AttentionSpec check below.
            if isinstance(spec, EncoderOnlyAttentionSpec):
                continue

            if isinstance(spec, AttentionSpec):
                group_attn = self.attn_groups[group_id]
                manager_block_size = kv_cache_group.kv_cache_spec.block_size
                kernel_block_sizes.append(
                    self.select_common_block_size(manager_block_size, group_attn)
                )
            elif isinstance(spec, MambaSpec):
                # Non-attention state cache: no virtual block splitting.
                kernel_block_sizes.append(spec.block_size)
            else:
                raise NotImplementedError(
                    f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
                )
        return kernel_block_sizes

6201
6202
6203
6204
    def _reshape_kv_cache_tensors(
        self,
        kv_cache_config: KVCacheConfig,
        kv_cache_raw_tensors: dict[str, torch.Tensor],
        kernel_block_sizes: list[int],
    ) -> dict[str, torch.Tensor]:
        """
        Reshape the KV cache tensors to the desired shape and dtype.

        Attention layers:
            The raw int8 buffer is viewed as ``kv_cache_spec.dtype`` and shaped
            by the backend's ``get_kv_cache_shape``, using kernel blocks. Each
            KV-manager block is ``block_size // kernel_block_size`` kernel
            blocks.
            When ``envs.VLLM_USE_FLASH_ATTN_PA`` is set and the model does not
            use MLA, the backend returns separate key and value shapes. The
            buffer is then split into a leading key region and a trailing
            value region, and the layer maps to a ``(key_cache, value_cache)``
            tuple.

        Mamba layers:
            The layer maps to a list of strided state tensors, one per entry of
            ``kv_cache_spec.shapes``. They are packed back to back inside each
            page.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer, with
                correct size but uninitialized shape.
            kernel_block_sizes: The kernel block sizes for each KV cache group.
        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache (tuples / lists in the
            cases described above).
        """
        import math

        kv_caches: dict[str, torch.Tensor] = {}
        has_attn, has_mamba = False, False
        for group in self._kv_cache_spec_attn_group_iterator():
            kv_cache_spec = group.kv_cache_spec
            attn_backend = group.backend
            if group.kv_cache_group_id == len(kernel_block_sizes):
                # There may be a last group for layers without kv cache.
                continue
            kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                raw_tensor = kv_cache_raw_tensors[layer_name]
                assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
                if isinstance(kv_cache_spec, AttentionSpec):
                    has_attn = True
                    # One KV-manager block is addressed by the kernel as
                    # several smaller kernel blocks.
                    num_blocks_per_kv_block = (
                        kv_cache_spec.block_size // kernel_block_size
                    )
                    kernel_num_blocks = num_blocks * num_blocks_per_kv_block

                    if (
                        envs.VLLM_USE_FLASH_ATTN_PA
                        and not self.vllm_config.model_config.use_mla
                    ):
                        # Backend returns separate key and value cache shapes.
                        key_cache_shape, value_cache_shape = (
                            attn_backend.get_kv_cache_shape(
                                kernel_num_blocks,
                                kernel_block_size,
                                kv_cache_spec.num_kv_heads,
                                kv_cache_spec.head_size,
                                cache_dtype_str=self.cache_config.cache_dtype,
                            )
                        )
                        dtype = kv_cache_spec.dtype
                        try:
                            key_stride_order, value_stride_order = (
                                attn_backend.get_kv_cache_stride_order()
                            )
                            # Each stride order must have the rank of its own
                            # shape (the key check used to compare the key
                            # order with itself and could never fail).
                            assert len(key_stride_order) == len(key_cache_shape)
                            assert len(value_stride_order) == len(value_cache_shape)
                        except (AttributeError, NotImplementedError):
                            key_stride_order = tuple(range(len(key_cache_shape)))
                            value_stride_order = tuple(range(len(value_cache_shape)))

                        # The allocation respects the backend-defined stride order
                        # to ensure the semantic remains consistent for each
                        # backend. We first obtain the generic kv cache shape and
                        # then permute it according to the stride order which could
                        # result in a non-contiguous tensor.
                        key_cache_shape = tuple(
                            key_cache_shape[i] for i in key_stride_order
                        )
                        value_cache_shape = tuple(
                            value_cache_shape[i] for i in value_stride_order
                        )

                        # Maintain original KV shape view.
                        inv_key_order = [
                            key_stride_order.index(i)
                            for i in range(len(key_stride_order))
                        ]
                        inv_value_order = [
                            value_stride_order.index(i)
                            for i in range(len(value_stride_order))
                        ]

                        typed_tensor = raw_tensor.view(dtype)
                        # Rank-agnostic element counts (previously hard-coded
                        # to exactly four dimensions).
                        key_elements = math.prod(key_cache_shape)
                        value_elements = math.prod(value_cache_shape)
                        assert typed_tensor.numel() == key_elements + value_elements

                        # Key cache occupies the front of the buffer, value
                        # cache the remainder.
                        key_cache = (
                            typed_tensor[:key_elements]
                            .view(key_cache_shape)
                            .permute(*inv_key_order)
                        )
                        value_cache = (
                            typed_tensor[key_elements:]
                            .view(value_cache_shape)
                            .permute(*inv_value_order)
                        )
                        kv_caches[layer_name] = (key_cache, value_cache)

                    else:
                        kv_cache_shape = attn_backend.get_kv_cache_shape(
                            kernel_num_blocks,
                            kernel_block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size,
                            cache_dtype_str=self.cache_config.cache_dtype,
                        )
                        dtype = kv_cache_spec.dtype
                        try:
                            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
                            assert len(kv_cache_stride_order) == len(kv_cache_shape)
                        except (AttributeError, NotImplementedError):
                            kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
                        # The allocation respects the backend-defined stride order
                        # to ensure the semantic remains consistent for each
                        # backend. We first obtain the generic kv cache shape and
                        # then permute it according to the stride order which could
                        # result in a non-contiguous tensor.
                        kv_cache_shape = tuple(
                            kv_cache_shape[i] for i in kv_cache_stride_order
                        )
                        # Maintain original KV shape view.
                        inv_order = [
                            kv_cache_stride_order.index(i)
                            for i in range(len(kv_cache_stride_order))
                        ]
                        kv_caches[layer_name] = (
                            raw_tensor.view(dtype)
                            .view(kv_cache_shape)
                            .permute(*inv_order)
                        )

                elif isinstance(kv_cache_spec, MambaSpec):
                    has_mamba = True
                    state_tensors = []
                    storage_offset_bytes = 0
                    for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
                        dtype_size = get_dtype_size(dtype)
                        num_element_per_page = (
                            kv_cache_spec.page_size_bytes // dtype_size
                        )
                        target_shape = (num_blocks, *shape)
                        # Contiguous strides of target_shape; a meta tensor
                        # yields them without allocating host memory.
                        stride = torch.empty(target_shape, device="meta").stride()
                        # Consecutive blocks are one full page apart; within a
                        # page the state is laid out contiguously.
                        target_stride = (num_element_per_page, *stride[1:])
                        assert storage_offset_bytes % dtype_size == 0
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset_bytes // dtype_size,
                        )
                        state_tensors.append(tensor)
                        # The next state starts right after this one's per-block
                        # footprint inside the page.
                        storage_offset_bytes += stride[0] * dtype_size

                    kv_caches[layer_name] = state_tensors
                else:
                    raise NotImplementedError

        if has_attn and has_mamba:
            # NOTE(review): in the VLLM_USE_FLASH_ATTN_PA path, attention layers
            # map to (key, value) tuples. _update_hybrid_attention_mamba_layout
            # reads `.shape` on each entry, so that combination appears
            # unsupported. Confirm before enabling hybrid models with it.
            self._update_hybrid_attention_mamba_layout(kv_caches)

        return kv_caches

6357
    def _update_hybrid_attention_mamba_layout(
        self, kv_caches: dict[str, torch.Tensor]
    ) -> None:
        """
        Re-stride attention caches from a (2, num_blocks, ...) memory layout to
        a (num_blocks, 2, ...) one, in place.

        The logical shape is unchanged; only the strides of the two leading
        dimensions are rewritten so that the K and V slices of one block become
        adjacent in memory. Only attention-spec tensors whose first dimension
        is 2 are touched.

        Args:
            kv_caches: The KV cache buffer of each layer.

        Raises:
            AssertionError: If both of the first two dimensions are 2, so the
                current layout cannot be determined.
        """
        for group in self._kv_cache_spec_attn_group_iterator():
            spec = group.kv_cache_spec
            for layer_name in group.layer_names:
                cache = kv_caches[layer_name]
                if not (isinstance(spec, AttentionSpec) and cache.shape[0] == 2):
                    continue
                assert cache.shape[1] != 2, (
                    "Fail to determine whether the layout is "
                    "(2, num_blocks, ...) or (num_blocks, 2, ...) for "
                    f"a tensor of shape {cache.shape}"
                )
                slice_numel = cache.shape[2:].numel()
                # Dim 0 (K/V) steps by one slice; dim 1 (block) steps over
                # both the K and the V slice of a block.
                new_strides = (slice_numel, 2 * slice_numel, *cache.stride()[2:])
                cache.as_strided_(size=cache.shape, stride=new_strides)
6383

6384
    def initialize_kv_cache_tensors(
        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
    ) -> dict[str, torch.Tensor]:
        """
        Initialize the memory buffer for KV cache.

        If ``self.use_uniform_kv_cache`` accepts the attention groups, the
        caches come from ``allocate_uniform_kv_caches``, a layout intended for
        kv-connector transfers. In that case ``self.cross_layers_kv_cache`` and
        ``self.cross_layers_attn_backend`` are also set. Otherwise raw buffers
        are allocated and reshaped per layer.

        Layers in ``self.shared_kv_cache_layers`` are then pointed at their
        target layer's cache. Finally, everything is bound via
        ``bind_kv_cache``.

        Args:
            kv_cache_config: The KV cache config
            kernel_block_sizes: The kernel block sizes for each KV cache group.

        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        cache_dtype = self.cache_config.cache_dtype
        uniform_layout = self.use_uniform_kv_cache(self.attn_groups, cache_dtype)
        if not uniform_layout:
            # General path: allocate raw byte buffers, then give each layer
            # its backend-specific view.
            raw_buffers = self._allocate_kv_cache_tensors(kv_cache_config)
            kv_caches = self._reshape_kv_cache_tensors(
                kv_cache_config, raw_buffers, kernel_block_sizes
            )
        else:
            # Layout optimized for kv-connector transfers.
            kv_caches, cross_layers_kv_cache, attn_backend = (
                self.allocate_uniform_kv_caches(
                    kv_cache_config,
                    self.attn_groups,
                    cache_dtype,
                    self.device,
                    kernel_block_sizes,
                )
            )
            self.cross_layers_kv_cache = cross_layers_kv_cache
            self.cross_layers_attn_backend = attn_backend

        # Cross-layer KV sharing: alias each sharing layer to its target.
        for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
            logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
            kv_caches[layer_name] = kv_caches[target_layer_name]

        is_longcat_flash = self.model_config.hf_config.model_type == "longcat_flash"
        num_attn_module = 2 if is_longcat_flash else 1
        bind_kv_cache(
            kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_caches,
            num_attn_module,
        )
        return kv_caches

    def maybe_add_kv_sharing_layers_to_kv_cache_groups(
        self, kv_cache_config: KVCacheConfig
    ) -> None:
        """
        Add layers that re-use another layer's KV cache to the KV cache group
        of their target layer. The tensors themselves are aliased later, in
        `initialize_kv_cache_tensors()`.

        With ``kv_sharing_fast_prefill`` enabled, this also records the
        fast-prefill-eligible layers. These are the trailing run of attention
        layers (in the order returned by ``get_layers_from_vllm_config``) that
        all reuse another layer's KV cache.
        """
        if not self.shared_kv_cache_layers:
            # No cross-layer KV sharing configured.
            return

        add_kv_sharing_layers_to_kv_cache_groups(
            self.shared_kv_cache_layers,
            kv_cache_config.kv_cache_groups,
            self.runner_only_attn_layers,
        )

        if not self.cache_config.kv_sharing_fast_prefill:
            return

        # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
        # similar KV sharing setups, only the layers that generate KV caches
        # are involved in the prefill phase, enabling prefill to early exit.
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name in reversed(attn_layers):
            if layer_name not in self.shared_kv_cache_layers:
                # Stop at the last layer that owns its KV cache.
                break
            self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)
6466

6467
6468
6469
6470
6471
6472
6473
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        The config is deep-copied, so the caller's object is not mutated, and
        stored on ``self.kv_cache_config``. The steps then run in this order:
        extend the config with encoder-only and KV-sharing layers, create
        attention backends, compute kernel block sizes, build metadata
        builders, possibly rebuild the input batch, and allocate and bind the
        KV cache tensors.

        Afterwards, when applicable, it validates the drafter's KV cache group,
        registers the caches with the KV transfer group, and sets up the routed
        experts capturer.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer
        """
        config = deepcopy(kv_cache_config)
        self.kv_cache_config = config
        self.may_add_encoder_only_layers_to_kv_cache_config()
        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(config)
        self.initialize_attn_backend(config)

        # Kernel block size per KV cache group. E.g. if the KV cache manager
        # uses 256-token blocks for a group but that group's attention backends
        # only support 64, the kernel block size is 64 and each 256-token block
        # is split into 4 kernel blocks of 64 tokens.
        kernel_block_sizes = self._prepare_kernel_block_sizes(config)

        self.initialize_metadata_builders(config, kernel_block_sizes)

        # Must run after initialize_attn_backend.
        self.may_reinitialize_input_batch(config, kernel_block_sizes)
        kv_caches = self.initialize_kv_cache_tensors(config, kernel_block_sizes)

        spec_config = self.speculative_config
        # The drafter may be absent on some ranks (presumably non-last pipeline
        # ranks; confirm), so validation is skipped rather than asserted there.
        if (
            spec_config
            and (spec_config.use_eagle() or spec_config.uses_draft_model())
            and isinstance(
                getattr(self, "drafter", None), EagleProposer | DraftModelProposer
            )
        ):
            # All draft-model layers must belong to the same KV cache group.
            self.drafter.validate_same_kv_cache_group(config)

        if has_kv_transfer_group():
            kv_transfer_group = get_kv_transfer_group()
            if self.cross_layers_kv_cache is None:
                kv_transfer_group.register_kv_caches(kv_caches)
            else:
                assert self.cross_layers_attn_backend is not None
                kv_transfer_group.register_cross_layers_kv_cache(
                    self.cross_layers_kv_cache, self.cross_layers_attn_backend
                )
            kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

        if self.model_config.enable_return_routed_experts:
            self.init_routed_experts_capturer()

    def init_routed_experts_capturer(self):
        """
        Create the routed-experts capturer and size its buffers.

        ``self.max_num_kv_tokens`` is set to
        ``(num_blocks // number_of_kv_cache_groups + 1) * cache block size``.
        That value is passed to ``init_buffer`` together with the scheduler's
        ``max_num_batched_tokens`` and the full ``vllm_config``.
        """
        logger.info(
            "Initializing routed experts capturer, enable_return_routed_experts: %s",
            self.model_config.enable_return_routed_experts,
        )
        capturer = RoutedExpertsCapturer.create()

        tokens_per_block = self.cache_config.block_size
        kv_cfg = self.kv_cache_config
        total_blocks = kv_cfg.num_blocks
        # Blocks are shared across groups; the +1 adds one block of headroom.
        # NOTE(review): a single cache_config.block_size is applied to every
        # group — confirm that no group uses a different block size.
        # Raises ZeroDivisionError if kv_cache_groups is empty.
        num_groups = len(kv_cfg.kv_cache_groups)
        self.max_num_kv_tokens = (
            total_blocks // num_groups + 1
        ) * tokens_per_block

        capturer.init_buffer(
            max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
            max_num_kv_tokens=self.max_num_kv_tokens,
            vllm_config=self.vllm_config,
        )

6536
6537
6538
6539
6540
    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
        """
        Register encoder-only attention layers as one extra KV cache group.

        Each ``Attention`` layer whose ``attn_type`` is
        ``AttentionType.ENCODER_ONLY`` gets the following treatment:
        - It receives an ``EncoderOnlyAttentionSpec``. The spec is built from
          the cache block size, the layer's ``num_kv_heads`` / ``head_size``,
          and ``self.kv_cache_dtype``.
        - Its name is added to ``self.runner_only_attn_layers``.

        If any such layers exist, they must all map to a single, equal spec
        (this is asserted). They are then appended to
        ``self.kv_cache_config.kv_cache_groups`` as one ``KVCacheGroupSpec``.
        """
        cache_block_size = self.vllm_config.cache_config.block_size
        # Group layer names by their spec so identical specs collapse together.
        layers_by_spec: dict[AttentionSpec, list[str]] = defaultdict(list)

        for name, module in get_layers_from_vllm_config(
            self.vllm_config, Attention
        ).items():
            if module.attn_type != AttentionType.ENCODER_ONLY:
                continue
            spec_for_layer: AttentionSpec = EncoderOnlyAttentionSpec(
                block_size=cache_block_size,
                num_kv_heads=module.num_kv_heads,
                head_size=module.head_size,
                dtype=self.kv_cache_dtype,
            )
            layers_by_spec[spec_for_layer].append(name)
            self.runner_only_attn_layers.add(name)

        if not layers_by_spec:
            return

        assert len(layers_by_spec) == 1, (
            "Only support one encoder-only attention spec now"
        )
        group_spec, group_layers = layers_by_spec.popitem()
        self.kv_cache_config.kv_cache_groups.append(
            KVCacheGroupSpec(layer_names=group_layers, kv_cache_spec=group_spec)
        )
6561

6562
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Build the KV cache spec of every attention layer in the model.

        Returns an empty dict when this worker is an EC-transfer producer.

        An ``Attention`` layer that sets ``kv_sharing_target_layer_name``
        reuses another layer's KV cache. It gets no spec of its own; its
        target is recorded in ``self.shared_kv_cache_layers`` instead. A layer
        whose ``get_kv_cache_spec`` returns a falsy value is also left out.

        Returns:
            KVCacheSpec: A dict mapping layer name to that layer's KV cache
            spec. Layers that need no KV cache of their own are omitted.
        """
        if has_ec_transfer() and get_ec_transfer().is_producer:
            return {}

        layers = get_layers_from_vllm_config(
            self.vllm_config, cast(type[Any], AttentionLayerBase)
        )
        specs: dict[str, KVCacheSpec] = {}
        for name, module in layers.items():
            # Only ``Attention`` instances are checked for cross-layer sharing.
            share_with = (
                module.kv_sharing_target_layer_name
                if isinstance(module, Attention)
                else None
            )
            if share_with:
                # Cross-layer KV sharing: with no spec for this layer, the KV
                # cache manager treats it as absent and allocates nothing for
                # it. The saved memory can hold longer contexts or more
                # concurrent requests.
                self.shared_kv_cache_layers[name] = share_with
                continue

            # A falsy spec marks a layer that needs no KV cache
            # (e.g. encoder-only attention).
            layer_spec = module.get_kv_cache_spec(self.vllm_config)
            if layer_spec:
                specs[name] = layer_spec
        return specs
6593

6594
6595
6596
6597
6598
6599
6600
6601
6602
    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
        """
        Copy ``sampled_token_ids`` to the host and return it as nested lists.

        The rows are copied, non-blocking, into a slice of
        ``self.sampled_token_ids_pinned_cpu`` (presumably a pinned host
        buffer, per its name). Completion is awaited through
        ``self.transfer_event`` before the slice is converted with ``tolist``.
        """
        # Short-term mitigation for
        # https://github.com/vllm-project/vllm/issues/22754.
        # Calling ``tolist`` directly would trigger a CUDA stream-wide sync and
        # block copy ops issued on other CUDA streams. Syncing on one CUDA
        # event avoids that. This path runs on every model forward step and
        # caused a perf issue in a disaggregated setup.
        num_rows = sampled_token_ids.shape[0]
        host_buf = self.sampled_token_ids_pinned_cpu[:num_rows]
        host_buf.copy_(sampled_token_ids, non_blocking=True)

        event = self.transfer_event
        event.record()
        event.synchronize()
        return host_buf.tolist()
6608
6609
6610
6611
6612
6613
6614
6615
6616
6617
6618
6619
6620
6621
6622
6623
6624
6625
6626
6627
6628
6629
6630
6631
6632
6633
6634
6635
6636
6637
6638
6639
6640
6641
6642
6643
6644
6645
6646
6647
6648
6649
6650
6651
6652
6653
6654
6655
6656
6657
6658
6659
6660
6661
6662
6663
6664
6665
6666
6667
6668
6669
6670
6671
6672
6673
6674
6675
6676
6677
6678
6679
6680
6681
6682
6683

    def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
        """
        Get encoder timing stats for all requests and clear the registry.

        Returns:
            Dictionary mapping request_id to stats dict.
        """
        with self._encoder_timing_lock:
            stats = {
                req_id: stats_obj.to_dict()
                for req_id, stats_obj in self.encoder_timing_registry.items()
            }
            self.encoder_timing_registry.clear()
            return stats

    @contextmanager
    def timed_encoder_operation(
        self,
        should_time: bool,
        group_lora_refs: list[tuple[str, Any]],
        current_item_idx: int,
        num_items: int,
    ):
        """
        Context manager that times one group of encoder forward passes.

        When ``should_time`` is false it simply yields, with no timing and no
        bookkeeping. Otherwise it does the following:
        - Synchronizes CUDA before and after the wrapped block and measures
          the elapsed ``time.perf_counter`` time.
        - Splits that time evenly across the distinct request ids in the
          group slice.
        - Under ``self._encoder_timing_lock``, adds each request's share to
          ``encoder_forward_time`` and increments ``num_encoder_calls`` in
          ``self.encoder_timing_registry``, creating an
          ``EncoderTimingStats`` entry when one is missing.

        The bookkeeping lives in a ``finally`` clause, so it also runs when
        the wrapped block raises (unless the trailing synchronize fails).

        Args:
            should_time: Whether timing is enabled.
            group_lora_refs: Full list of (request_id, pos_info) tuples.
            current_item_idx: Starting index of this group in
                ``group_lora_refs``.
            num_items: Number of items in this group.
        """
        if should_time:
            stop_idx = current_item_idx + num_items
            request_ids = {
                rid for rid, _ in group_lora_refs[current_item_idx:stop_idx]
            }
            # Drain previously queued GPU work so it is not billed to this block.
            torch.cuda.synchronize()
            t_start = time.perf_counter()
            try:
                yield
            finally:
                # Wait for the encoder kernels to finish before reading the clock.
                torch.cuda.synchronize()
                duration = time.perf_counter() - t_start
                share = duration / max(len(request_ids), 1)
                with self._encoder_timing_lock:
                    registry = self.encoder_timing_registry
                    for rid in request_ids:
                        if rid not in registry:
                            registry[rid] = EncoderTimingStats()
                        entry = registry[rid]
                        entry.encoder_forward_time += share
                        entry.num_encoder_calls += 1
        else:
            yield


@dataclass
class EncoderTimingStats:
    """Per-request timing statistics for encoder forward pass."""

    encoder_forward_time: float = 0.0
    """Time spent in vision encoder forward pass (seconds)."""

    num_encoder_calls: int = 0
    """Number of times encoder was called for this request."""

    def to_dict(self) -> dict[str, float | int]:
        return {
            "encoder_forward_time": self.encoder_forward_time,
            "num_encoder_calls": self.num_encoder_calls,
        }