model_runner.py 53.4 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""
NOTE: Coding style guide for this file:
This model runner is shared by all models: text and multimodal, generative
and embedding, public and private. As a result, this file must only contain
code that is common to every model. Model-specific behavior belongs in the
appropriate model-specific files.

In other words:
* Be paranoid about changing this file. It should remain stable.
* Be even more paranoid about adding new lines. It should remain minimal.

Even for shared features (for example, different parallelism modes), keep the
complexity out of this path. The less common the feature, the more it should be
hidden. Prefer utility functions defined elsewhere and call them from here,
instead of embedding feature-specific logic directly.
"""

20
import functools
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
import gc
import time
from copy import deepcopy
24
from typing import Any, NamedTuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29
30
31

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
32
from vllm.distributed.parallel_state import (
33
    get_dcp_group,
34
35
36
    get_pp_group,
    prepare_communication_buffer_for_model,
)
37
from vllm.forward_context import BatchDescriptor, set_forward_context
Woosuk Kwon's avatar
Woosuk Kwon committed
38
39
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
40
from vllm.multimodal import MULTIMODAL_REGISTRY
41
from vllm.sequence import IntermediateTensors
42
from vllm.tasks import SupportedTask
43
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
Woosuk Kwon's avatar
Woosuk Kwon committed
44
45
46
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
47
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
48
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
49
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
50
from vllm.v1.worker.gpu.attn_utils import (
51
    build_slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
52
53
54
55
56
    get_kv_cache_spec,
    init_attn_backend,
    init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
57
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
58
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
59
60
61
62
63
64
from vllm.v1.worker.gpu.cudagraph_utils import (
    BatchExecutionDescriptor,
    ModelCudaGraphManager,
    get_uniform_token_count,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
Woosuk Kwon's avatar
Woosuk Kwon committed
65
66
67
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
68
    combine_sampled_and_draft_tokens,
69
    expand_idx_mapping,
70
    get_num_sampled_and_rejected,
71
    post_update,
72
    post_update_pool,
73
74
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
Woosuk Kwon's avatar
Woosuk Kwon committed
75
)
76
77
78
79
80
from vllm.v1.worker.gpu.kv_connector import (
    NO_OP_KV_CONNECTOR,
    KVConnector,
    get_kv_connector,
)
81
from vllm.v1.worker.gpu.lora_utils import LoraState
82
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
83
from vllm.v1.worker.gpu.model_states import init_model_state
84
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
85
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
86
from vllm.v1.worker.gpu.sample.output import SamplerOutput
87
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
88
from vllm.v1.worker.gpu.sample.sampler import Sampler
89
from vllm.v1.worker.gpu.spec_decode import init_speculator
90
91
92
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
    set_eagle3_aux_hidden_state_layers,
)
93
from vllm.v1.worker.gpu.spec_decode.rejection_sampler import RejectionSampler
94
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
95
from vllm.v1.worker.gpu.states import RequestState
96
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Woosuk Kwon's avatar
Woosuk Kwon committed
97
98
99
100
101
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

# Module-level logger for this runner, created through vLLM's `init_logger`.
logger = init_logger(__name__)


102
class GPUModelRunner(LoRAModelRunnerMixin):
103
    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Build the runner's per-worker state from the engine configuration.

        Only configuration, buffers and helper workers that do not need the
        loaded model are created here. Model-dependent components (model
        state, pooling runner, PP intermediate tensors) are built in
        ``load_model``; KV-cache-dependent ones (block tables, attention
        backends, KV caches, KV connector) in ``initialize_kv_cache``.

        Args:
            vllm_config: The full engine configuration.
            device: The device this runner executes on.

        Raises:
            ValueError: If EAGLE3 speculative decoding is combined with
                pipeline parallelism.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.device = device
        self.dtype = self.model_config.dtype
        self.kv_cache_dtype = self.dtype
        if self.cache_config.cache_dtype != "auto":
            # Quantized KV cache.
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype
            ]

        self.vocab_size = self.model_config.get_vocab_size()
        self.max_model_len = self.model_config.max_model_len
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.is_encoder_decoder = self.model_config.is_encoder_decoder

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(self.device)
        self.output_copy_event = torch.cuda.Event()

        # Pipeline parallelism.
        self.pp_size = self.parallel_config.pipeline_parallel_size
        self.use_pp = self.pp_size > 1
        if self.use_pp:
            self.is_first_pp_rank = get_pp_group().is_first_rank
            self.is_last_pp_rank = get_pp_group().is_last_rank
        else:
            self.is_first_pp_rank = True
            self.is_last_pp_rank = True
        # Persistent buffer for intermediate tensors (non-first PP ranks);
        # allocated in load_model once the model is available.
        self.intermediate_tensors: IntermediateTensors | None = None

        # Data parallelism.
        self.dp_size = self.parallel_config.data_parallel_size
        self.dp_rank = self.parallel_config.data_parallel_rank

        # Decode context parallelism.
        self.dcp_size = self.parallel_config.decode_context_parallel_size
        self.use_dcp = self.dcp_size > 1
        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size

        # Multimodal: the encoder cache only lives on the first PP rank.
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config
        )
        self.encoder_cache: EncoderCache | None = None
        if self.supports_mm_inputs and self.is_first_pp_rank:
            self.encoder_cache = EncoderCache()

        # Speculative decoding.
        self.speculator = None
        self.num_speculative_steps = 0
        self.use_aux_hidden_state_outputs = False
        use_strict_rejection_sampling = False
        if self.speculative_config is not None:
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
            use_strict_rejection_sampling = (
                self.speculative_config.rejection_sample_method == "strict"
            )

            if self.speculative_config.method == "eagle3":
                # Validate before init_speculator runs so an unsupported
                # configuration fails fast instead of after building the
                # speculator.
                if self.use_pp:
                    raise ValueError("EAGLE3 with pipeline parallel is not supported.")
                # EAGLE3 may require auxiliary hidden states from target model outputs.
                self.use_aux_hidden_state_outputs = True

            if self.is_last_pp_rank:
                self.speculator = init_speculator(self.vllm_config, self.device)

        # Draft tokens propagation - for spec-dec + struct outputs.
        self.draft_tokens_handler = DraftTokensHandler(self.device)

        # Pooling models.
        self.is_pooling_model = self.model_config.runner_type == "pooling"
        self.pooling_runner: PoolingRunner | None = None

        # General request states.
        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=self.device,
        )
        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            device=self.device,
        )

        self.sampler: Sampler | None = None
        self.rejection_sampler: RejectionSampler | None = None
        self.prompt_logprobs_worker: PromptLogprobsWorker | None = None
        self.structured_outputs_worker: StructuredOutputsWorker | None = None
        if self.is_last_pp_rank and not self.is_pooling_model:
            # Sampling-related workers are only needed on the last PP rank
            # and only for generative (non-pooling) models.
            self.sampler = Sampler(
                max_num_reqs=self.max_num_reqs,
                vocab_size=self.vocab_size,
                device=self.device,
                req_states=self.req_states,
                logprobs_mode=self.model_config.logprobs_mode,
                num_speculative_tokens=self.num_speculative_steps + 1,
            )
            self.rejection_sampler = RejectionSampler(
                self.sampler,
                num_speculative_steps=self.num_speculative_steps,
                use_strict_rejection_sampling=use_strict_rejection_sampling,
            )
            self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
            self.structured_outputs_worker = StructuredOutputsWorker(
                max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
                vocab_size=self.vocab_size,
                device=self.device,
            )

        # CUDA graphs. A decode step covers the sampled token plus all drafts.
        self.decode_query_len = self.num_speculative_steps + 1
        self.cudagraph_manager = ModelCudaGraphManager(
            self.vllm_config,
            self.device,
            self.compilation_config.cudagraph_mode,
            decode_query_len=self.decode_query_len,
        )

        # LoRA-related workers.
        self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)

        # KV connector; replaced in initialize_kv_cache if one is configured.
        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

        # For transferring state from execute_model to subsequent sample_tokens call.
        self.execute_model_state: ExecuteModelState | None = None
250

251
252
253
254
    def update_max_model_len(self, max_model_len: int) -> None:
        """Apply a new maximum model length to the runner and its request states."""
        # Chained assignment: targets are assigned left to right.
        self.max_model_len = self.req_states.max_model_len = max_model_len

255
256
257
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks this runner can serve for the loaded model.

        Generation tasks are reported by the model state when the runner type
        is ``"generate"``; pooling tasks are queried from the model itself.
        """
        supported: list[SupportedTask] = []
        if self.model_config.runner_type == "generate":
            supported += self.model_state.get_supported_generation_tasks()
        if self.is_pooling_model:
            # Ask the model directly instead of self.pooling_runner: this answer
            # is needed on the first PP rank too, but pooling_runner is only
            # created on the last PP rank.
            supported += PoolingRunner.get_supported_tasks(self.model)
        return tuple(supported)
Woosuk Kwon's avatar
Woosuk Kwon committed
265
266
267
268
269
270
271
272
273
274
275
276
277

    def load_model(self, *args, **kwargs) -> None:
        """Load the target model (and the draft model, if any) and model-bound state.

        Positional and keyword arguments are accepted for interface
        compatibility and are not used.
        """
        load_start = time.perf_counter()
        with DeviceMemoryProfiler() as profiler:
            loader = get_model_loader(self.vllm_config.load_config)
            logger.info("Loading model from scratch...")

            self.model = loader.load_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.model_config,
            )
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model, self.vllm_config, self.device
                )

            if self.use_aux_hidden_state_outputs:
                # EAGLE3: tell the target model which layers to expose.
                assert self.speculative_config is not None
                set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
            if self.speculator is not None:
                self.speculator.load_model(self.model)
        load_seconds = time.perf_counter() - load_start

        self.model_memory_usage = profiler.consumed_memory
        logger.info(
            "Model loading took %s GiB and %.6f seconds",
            format_gib(profiler.consumed_memory),
            load_seconds,
        )

        # Communication buffers for the target model and, if present, the draft.
        prepare_communication_buffer_for_model(self.model)
        if self.speculator is not None:
            prepare_communication_buffer_for_model(self.speculator.model)

        # Components that can only be built once the model exists.
        self.model_state = init_model_state(
            self.vllm_config, self.model, self.encoder_cache, self.device
        )
        if self.is_pooling_model and self.is_last_pp_rank:
            self.pooling_runner = PoolingRunner(self.model)

        if self.is_first_pp_rank:
            return
        # Non-first PP ranks: allocate intermediate tensors at the maximum
        # token count so each batch can use a slice. They are stored on the
        # runner so received data is copied into the same addresses that the
        # CUDA graphs captured.
        self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
            batch_size=self.max_num_tokens,
            dtype=self.model_config.dtype,
            device=self.device,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
317
318
319
    def get_model(self) -> nn.Module:
        """Return the target model assigned by ``load_model``."""
        model = self.model
        return model

320
321
322
323
324
    @functools.cached_property
    def main_stream(self) -> torch.cuda.Stream:
        """The stream that is current on ``self.device`` when first accessed.

        Cached after the first lookup so hot paths skip repeated stream queries.
        """
        stream = torch.cuda.current_stream(self.device)
        return stream

Woosuk Kwon's avatar
Woosuk Kwon committed
325
326
327
328
329
330
331
332
333
334
335
    def get_kv_cache_spec(self):
        """Return the KV-cache spec computed from this runner's engine config."""
        config = self.vllm_config
        return get_kv_cache_spec(config)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate KV caches and build everything that depends on their layout.

        The runner keeps its own deep copy of ``kv_cache_config``. This method
        builds block tables, attention backends/groups, the KV cache tensors
        and the KV connector.
        """
        config = deepcopy(kv_cache_config)
        self.kv_cache_config = config
        block_sizes = [
            group.kv_cache_spec.block_size for group in config.kv_cache_groups
        ]

        # Cross-attention block tables must be able to index encoder tokens
        # (e.g., Whisper ~1500), which can exceed the decoder max_model_len.
        table_max_len = self.max_model_len
        if self.is_encoder_decoder:
            encoder_positions = getattr(
                self.model_config.hf_config, "max_source_positions", 0
            )
            table_max_len = max(table_max_len, encoder_positions)

        self.block_tables = BlockTables(
            block_sizes=block_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=table_max_len,
            device=self.device,
            cp_size=self.dcp_size,
            cp_rank=self.dcp_rank,
            cp_interleave=self.cp_interleave,
        )

        self.attn_backends, self.attn_groups = init_attn_backend(
            self.kv_cache_config, self.vllm_config, self.device
        )
        check_attention_cp_compatibility(self.vllm_config)
        if self.speculator is not None:
            # HACK(woosuk): the speculator reuses the target's model state,
            # KV-cache config and block tables.
            self.speculator.set_attn(
                self.model_state,
                self.kv_cache_config,
                self.block_tables,
            )

        self.kv_caches: list[torch.Tensor] = []
        kv_caches_by_name = init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
            self.cache_config.cache_dtype,
        )
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_by_name)

Woosuk Kwon's avatar
Woosuk Kwon committed
379
380
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        *args,
        skip_attn: bool = True,
        uniform_decode: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run the model (and the speculator, if any) on a synthetic batch.

        The KV connector is disabled for the forward pass. It is re-enabled
        afterwards even if the forward pass raises, so a failed dummy run
        cannot leave the connector permanently off.

        Args:
            num_tokens: Total number of tokens in the dummy batch. With
                ``uniform_decode`` it is raised to at least
                ``decode_query_len`` and must be a multiple of it.
            skip_attn: Forwarded as ``skip_attn_for_dummy_run``.
            uniform_decode: Build requests of exactly ``decode_query_len``
                tokens each.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(hidden_states, sample_hidden_states)`` on the last PP rank, where
            ``sample_hidden_states`` is ``hidden_states`` gathered at the
            batch's logits indices; ``(None, None)`` on other PP ranks.
        """
        # Create a dummy scheduler output.
        num_reqs = min(num_tokens, self.max_num_reqs)
        if uniform_decode:
            # HACK(lucas): for now since the worker is shared between MRV1 and MRV2,
            # and for spec-decode with MTP we want to make sure the dummy runs use
            # 1+num_speculative_tokens we use max here, this will likely be eventually
            # changed in the worker: https://github.com/vllm-project/vllm/pull/35243
            # NOTE(review): num_reqs is not clamped to max_num_reqs on this path;
            # confirm callers never exceed max_num_reqs * decode_query_len.
            num_tokens = max(num_tokens, self.decode_query_len)
            num_reqs = num_tokens // self.decode_query_len
            assert num_tokens % self.decode_query_len == 0
        # Spread tokens evenly; the last request absorbs the remainder.
        num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
        num_tokens_per_request[-1] += num_tokens % num_reqs
        assert sum(num_tokens_per_request) == num_tokens

        num_scheduled_tokens = {
            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
        }
        dummy_scheduler_output = SchedulerOutput.make_empty()
        dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
        dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens

        # Get the intermediate tensors for the dummy run (non-first PP ranks
        # read a slice of the persistent buffer).
        intermediate_tensors = None
        if not self.is_first_pp_rank:
            assert self.intermediate_tensors is not None
            intermediate_tensors = self.intermediate_tensors[:num_tokens]

        # Disable any use of KVConnector for dummy runs, and always restore it.
        self.kv_connector.set_disabled(True)
        try:
            self.execute_model(
                dummy_scheduler_output,
                intermediate_tensors=intermediate_tensors,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
            )
        finally:
            self.kv_connector.set_disabled(False)

        # Non-last PP ranks don't produce output for sampling.
        if not self.is_last_pp_rank:
            return None, None

        # Consume the state stashed by execute_model and clear it.
        assert self.execute_model_state is not None
        input_batch = self.execute_model_state.input_batch
        attn_metadata = self.execute_model_state.attn_metadata
        slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
        hidden_states = self.execute_model_state.hidden_states
        aux_hidden_states = self.execute_model_state.aux_hidden_states
        num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
        self.execute_model_state = None

        # Dummy-run the eagle speculator's propose to ensure DP/EP sync.
        if self.speculator is not None:
            assert self.sampler is not None
            mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
            if self.speculator.supports_mm_inputs:
                mm_inputs = (
                    [],
                    torch.zeros(
                        input_batch.num_tokens,
                        dtype=torch.bool,
                        device=self.device,
                    ),
                )
            self.speculator.propose(
                input_batch=input_batch,
                attn_metadata=attn_metadata,
                slot_mappings=slot_mappings_by_layer,
                last_hidden_states=hidden_states,
                aux_hidden_states=aux_hidden_states,
                num_sampled=torch.ones(
                    input_batch.num_reqs, dtype=torch.int32, device=self.device
                ),
                num_rejected=torch.zeros(
                    input_batch.num_reqs, dtype=torch.int32, device=self.device
                ),
                last_sampled=self.req_states.last_sampled_tokens,
                next_prefill_tokens=self.req_states.next_prefill_tokens,
                temperature=self.sampler.sampling_states.temperature.gpu,
                seeds=self.sampler.sampling_states.seeds.gpu,
                num_tokens_across_dp=num_tokens_across_dp,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
                mm_inputs=mm_inputs,
            )

        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        return hidden_states, sample_hidden_states

    @torch.inference_mode()
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
        """Compute logits for ``hidden_states`` and push them through the sampler.

        The leading dimension of ``hidden_states`` is used as the dummy request
        count for the synthetic input batch.
        """
        num_dummy_reqs = hidden_states.shape[0]
        logits = self.model.compute_logits(hidden_states)
        batch = InputBatch.make_dummy(
            num_dummy_reqs, num_dummy_reqs, self.input_buffers
        )

        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
        # top_k, top_p, and logprobs, using less GPU memory than what is possible
        # during actual execution.
        sampler = self.sampler
        assert sampler is not None
        sampler(logits, batch)
Woosuk Kwon's avatar
Woosuk Kwon committed
492

493
494
495
496
497
    @torch.inference_mode()
    def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        """Run the pooling runner's dummy pass over ``hidden_states``."""
        runner = self.pooling_runner
        assert runner is not None
        runner.dummy_pooler_run(hidden_states)

Woosuk Kwon's avatar
Woosuk Kwon committed
498
499
500
    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a max-size dummy batch (plus sampler/pooler) for memory profiling."""
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens, skip_attn=True
        )

        # Non-last PP ranks get (None, None) back and have nothing to sample/pool.
        if self.is_last_pp_rank:
            assert sample_hidden_states is not None
            if self.pooling_runner is not None:
                self._dummy_pooler_run(hidden_states)
            else:
                self._dummy_sampler_run(sample_hidden_states)

        torch.accelerator.synchronize()
        # Drop the references before collecting so the tensors can be released.
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        """Reset the multimodal cache through the encoder cache, if present."""
        cache = self.encoder_cache
        if cache is None:
            return
        cache.reset_mm_cache()
519
520

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache, if this runner has one."""
        cache = self.encoder_cache
        if cache is None:
            return
        cache.reset_encoder_cache()
Woosuk Kwon's avatar
Woosuk Kwon committed
523
524
525
526
527

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        """Return the input token count for a step.

        SP is not supported yet, so this is the scheduled count unchanged.
        """
        num_input_tokens = num_scheduled_tokens
        return num_input_tokens

528
529
530
531
    def profile_cudagraph_memory(self) -> int:
        """Report CUDA graph memory for profiling; this runner always reports 0.

        NOTE(woosuk): It is TBD whether we keep this API or not.
        """
        no_profiled_bytes = 0
        return no_profiled_bytes

Woosuk Kwon's avatar
Woosuk Kwon committed
532
533
534
535
536
537
538
539
540
541
    @torch.inference_mode()
    def capture_model(self) -> int:
        """Capture CUDA graphs for the model and, if present, the speculator.

        Capture runs with dummy LoRAs set up via ``maybe_setup_dummy_loras``.

        Returns:
            Bytes of device memory consumed by capture, measured as the drop
            in free memory on ``self.device``; 0 if capture is skipped.
        """
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        start_time = time.perf_counter()
        gc.collect()
        torch.accelerator.empty_cache()
        # Query this runner's device explicitly rather than whichever CUDA
        # device happens to be current, so the delta reflects this capture.
        start_free_gpu_memory = torch.cuda.mem_get_info(self.device)[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            self.cudagraph_manager.capture(
                self.model,
                self.model_state,
                self.input_buffers,
                self.intermediate_tensors,
                self.block_tables,
                self.attn_groups,
                self.kv_cache_config,
                has_lora=self.lora_config is not None,
                use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
            )
            if self.speculator is not None:
                self.speculator.capture_model()

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info(self.device)[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %s GiB",
            elapsed_time,
            format_gib(cuda_graph_size),
        )
        return cuda_graph_size

573
574
575
576
577
578
579
580
581
582
    def _remove_request(self, req_id: str) -> bool:
        """Drop all per-request state held for ``req_id``.

        Returns:
            False if ``req_states`` did not know the request (no other
            component is touched then); True after removing it everywhere.
        """
        removed = self.req_states.remove_request(req_id)
        if not removed:
            return False
        for optional_worker in (self.encoder_cache, self.prompt_logprobs_worker):
            if optional_worker is not None:
                optional_worker.remove_request(req_id)
        self.lora_state.remove_request(req_id)
        return True

583
    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Release state for requests that finished or were preempted this step."""
        finished = scheduler_output.finished_req_ids
        preempted = scheduler_output.preempted_req_ids
        to_remove = finished.union(preempted) if preempted else finished
        for req_id in to_remove:
            self._remove_request(req_id)
590

591
    def free_states(self, scheduler_output: SchedulerOutput) -> None:
        """Free encoder-cache entries whose mm hashes the scheduler released."""
        cache = self.encoder_cache
        if cache is None:
            return
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            cache.free_encoder_cache(mm_hash)
Woosuk Kwon's avatar
Woosuk Kwon committed
595

596
    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Register newly scheduled requests with every per-request component.

        An already-known request ID (streaming input update) is removed first
        and then re-added with its updated prompt and multimodal features.
        Staged writes are flushed after all new requests are registered.
        """
        new_reqs = scheduler_output.scheduled_new_reqs
        for req_data in new_reqs:
            assert req_data.prompt_token_ids is not None
            assert req_data.prefill_token_ids is not None
            req_id = req_data.req_id

            # Streaming input update: request already exists from a prior
            # chunk. Remove old state so it can be cleanly re-added below
            # with the updated prompt_token_ids and mm_features.
            self._remove_request(req_id)

            num_prompt_tokens = len(req_data.prompt_token_ids)
            self.req_states.add_request(
                req_id=req_id,
                prompt_len=num_prompt_tokens,
                all_token_ids=req_data.prefill_token_ids,
                num_computed_tokens=req_data.num_computed_tokens,
            )
            slot = self.req_states.req_id_to_index[req_id]

            if self.encoder_cache is not None:
                self.encoder_cache.add_request(req_id, req_data.mm_features)
            self.model_state.add_request(slot, req_data)
            self.block_tables.append_block_ids(
                slot, req_data.block_ids, overwrite=True
            )
            self.lora_state.add_request(req_id, slot, req_data.lora_request)

            # Sampling state lives only on the last PP rank.
            sampling_params = req_data.sampling_params
            if not self.is_last_pp_rank or sampling_params is None:
                continue
            assert self.sampler is not None
            self.sampler.add_request(slot, num_prompt_tokens, sampling_params)
            assert self.prompt_logprobs_worker is not None
            self.prompt_logprobs_worker.add_request(req_id, slot, sampling_params)

        if new_reqs:
            self.req_states.apply_staged_writes()
            self.model_state.apply_staged_writes()
        # Unlike the two flushes above, the sampler is flushed on every call.
        if self.sampler is not None:
            self.sampler.apply_staged_writes()
640
641

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Append newly allocated KV blocks to requests that are already running."""
        cached = scheduler_output.scheduled_cached_reqs
        for req_id, block_ids in zip(cached.req_ids, cached.new_block_ids):
            if block_ids is None:
                # No new blocks for this request this step.
                continue
            slot = self.req_states.req_id_to_index[req_id]
            self.block_tables.append_block_ids(slot, block_ids, overwrite=False)
Woosuk Kwon's avatar
Woosuk Kwon committed
650
651

    def prepare_inputs(
        self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
    ) -> InputBatch:
        """Build the InputBatch for one step from the scheduler output.

        Requests are ordered by their number of scheduled tokens (ascending), so
        requests with few tokens (decodes) come before prefills. The method
        populates the persistent ``self.input_buffers``. It writes
        ``query_start_loc`` directly via ``async_copy_to_gpu(out=...)``. The
        positions, seq_lens, input_ids and (with DCP) dcp_local_seq_lens
        buffers are handed to the ``prepare_*`` / ``combine_*`` helpers,
        presumably filled in place. The returned InputBatch holds views into
        these buffers, sliced to the padded sizes from ``batch_desc``.

        Args:
            scheduler_output: Scheduler decisions for this step. Must schedule
                at least one token (asserted).
            batch_desc: Padded batch shape for this step. ``num_tokens`` is the
                padded token count. ``num_reqs`` is the padded request count, or
                falsy when no request padding applies.

        Returns:
            The assembled InputBatch, including logits indices and cumulative
            logits counts that account for scheduled draft tokens.
        """
        num_tokens = scheduler_output.total_num_scheduled_tokens
        num_tokens_after_padding = batch_desc.num_tokens
        assert num_tokens > 0
        num_tokens_per_req = scheduler_output.num_scheduled_tokens
        num_reqs = len(num_tokens_per_req)

        # Decode first, then prefill.
        # batch_idx -> req_id
        req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get)  # type: ignore[arg-type]
        numtoks_iter = map(num_tokens_per_req.get, req_ids)
        num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)

        # batch_idx -> persistent request-state row index.
        idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
        idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
        idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)

        # Get the number of draft tokens for each request.
        draft_tokens = scheduler_output.scheduled_spec_decode_tokens
        if not draft_tokens:
            # No draft token scheduled (common case): exactly one logit per
            # request, so the cumulative count is simply 0..num_reqs.
            total_num_draft_tokens = 0
            total_num_logits = num_reqs
            cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
            cu_num_logits = torch.arange(
                num_reqs + 1, device=self.device, dtype=torch.int32
            )
            expanded_idx_mapping = idx_mapping
            expanded_local_pos = torch.zeros(
                num_reqs, dtype=torch.int32, device=self.device
            )
        else:
            # Each request produces (num_draft_tokens + 1) logits.
            num_draft_tokens = np.fromiter(
                (len(draft_tokens.get(req_id, ())) for req_id in req_ids),
                dtype=np.int32,
                count=num_reqs,
            )
            total_num_draft_tokens = int(num_draft_tokens.sum())
            total_num_logits = num_reqs + total_num_draft_tokens

            num_logits = num_draft_tokens + 1
            cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
            cu_num_logits_np[0] = 0
            np.cumsum(num_logits, out=cu_num_logits_np[1:])
            cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)

            max_expand_len = self.num_speculative_steps + 1
            expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
                idx_mapping, total_num_logits, cu_num_logits, max_expand_len
            )

        # Get query_start_loc.
        # batch_desc.num_reqs is None for PIECEWISE graphs (no request padding
        # needed), so fall back to the real number of requests.
        num_reqs_padded = batch_desc.num_reqs or num_reqs
        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
        query_start_loc_np[num_reqs + 1 :] = num_tokens
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)

        # NOTE: query_start_loc_np is only sliced (never written) after the
        # async copy above has been issued.
        query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1]

        # Get prefill tokens if any.
        if self.req_states.any_prefills(idx_mapping_np):
            prepare_prefill_inputs(
                self.input_buffers.input_ids,
                self.req_states.next_prefill_tokens,
                idx_mapping,
                query_start_loc,
                self.req_states.all_token_ids.gpu,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )

        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
            idx_mapping,
            query_start_loc,
            self.req_states.num_computed_tokens.gpu,
            self.input_buffers.positions,
            self.input_buffers.seq_lens,
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs_padded]

        dcp_local_seq_lens = None
        if self.use_dcp:
            # Prepare dcp local seq_lens.
            prepare_dcp_local_seq_lens(
                self.input_buffers.dcp_local_seq_lens,
                self.input_buffers.seq_lens,
                num_reqs,
                self.dcp_size,
                self.dcp_rank,
                self.cp_interleave,
            )
            dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded]

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
            self.input_buffers.input_ids,
            idx_mapping,
            self.req_states.last_sampled_tokens,
            query_start_loc,
            seq_lens,
            self.req_states.prefill_len.gpu,
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
        )

        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            num_reqs_after_padding=num_reqs_padded,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            expanded_idx_mapping=expanded_idx_mapping,
            expanded_local_pos=expanded_local_pos,
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
            num_draft_tokens=total_num_draft_tokens,
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            dcp_local_seq_lens=dcp_local_seq_lens,
            input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
            positions=self.input_buffers.positions[:num_tokens_after_padding],
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
        )

790
791
792
    def prepare_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        """Gather block tables and compute KV-cache slot mappings for a batch.

        Args:
            input_batch: Batch produced by ``prepare_inputs``.

        Returns:
            ``(block_tables, slot_mappings)``. ``block_tables`` has one table
            per KV-cache group, gathered for ``input_batch.idx_mapping`` and
            padded to ``input_batch.num_reqs_after_padding`` rows.
            ``slot_mappings`` covers ``input_batch.num_tokens_after_padding``
            tokens.
        """
        # Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
        block_tables = self.block_tables.gather_block_tables(
            input_batch.idx_mapping,
            num_reqs_padded=input_batch.num_reqs_after_padding,
        )
        # Slot mappings: [num_kv_cache_groups, num_tokens_padded].
        # Kernel pads beyond num_tokens with PAD_SLOT_ID.
        slot_mappings = self.block_tables.compute_slot_mappings(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            input_batch.positions,
            num_tokens_padded=input_batch.num_tokens_after_padding,
        )
        return block_tables, slot_mappings

    def prepare_dummy_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        """Return placeholder block tables and slot mappings for a dummy batch.

        ``execute_model`` uses this instead of ``prepare_attn`` during dummy
        runs. The sizes come from the dummy batch's ``num_reqs`` and
        ``num_tokens``.
        """
        tables = self.block_tables
        dummy_tables = tables.get_dummy_block_tables(input_batch.num_reqs)
        dummy_slots = tables.get_dummy_slot_mappings(input_batch.num_tokens)
        return dummy_tables, dummy_slots

Woosuk Kwon's avatar
Woosuk Kwon committed
817
818
819
820
821
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
        """Compute logits at the batch's logits indices and sample tokens.

        If ``grammar_output`` is given, its bitmask is applied to the logits
        in place first. Without draft tokens the regular sampler is used.
        Otherwise the rejection sampler verifies the drafts, using the
        speculator's draft logits.

        Args:
            hidden_states: Final hidden states from the model forward pass.
            input_batch: Batch that produced ``hidden_states``.
            grammar_output: Structured-output bitmask, or None.

        Returns:
            ``(sampler_output, num_sampled, num_rejected)``. The last two come
            from ``get_num_sampled_and_rejected`` and are 0 for chunked
            prefills.
        """
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        logits = self.model.compute_logits(sample_hidden_states)
        if grammar_output is not None:
            # Apply grammar bitmask to the logits in-place.
            assert self.structured_outputs_worker is not None
            self.structured_outputs_worker.apply_grammar_bitmask(
                logits,
                input_batch,
                grammar_output.structured_output_request_ids,
                grammar_output.grammar_bitmask,
            )

        if input_batch.num_draft_tokens == 0:
            # No draft tokens (common case).
            assert self.sampler is not None
            sampler_output = self.sampler(logits, input_batch)
        else:
            # Rejection sampling for spec decoding.
            assert self.rejection_sampler is not None
            assert self.speculator is not None
            sampler_output = self.rejection_sampler(
                logits,
                input_batch,
                # Draft logits are needed for probabilistic rejection sampling.
                self.speculator.draft_logits,
            )

        # Get the number of sampled and rejected tokens.
        # For chunked prefills, num_sampled and num_rejected are both 0.
        num_sampled, num_rejected = get_num_sampled_and_rejected(
            sampler_output.num_sampled,
            input_batch.seq_lens,
            input_batch.cu_num_logits,
            input_batch.idx_mapping,
            self.req_states.prefill_len.gpu,
        )
        return sampler_output, num_sampled, num_rejected

    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Fold one step's sampling results into the per-request state.

        Called on every PP rank. Non-last ranks pass the values received via
        ``pp_receive``. The GPU-side update is delegated to ``post_update``.
        The CPU-side ``num_computed_prefill_tokens`` counter is advanced by the
        scheduled token counts and clamped to each request's prefill length.

        Args:
            input_batch: Batch that was just executed.
            sampled_tokens: Sampled token IDs for the batch.
            num_sampled: Sampled-token counts from ``sample``/``pp_receive``.
            num_rejected: Rejected-token counts from ``sample``/``pp_receive``.
        """
        # Update the number of computed tokens.
        # The sampler's penalty bin counts are only used on the last PP rank.
        if self.is_last_pp_rank:
            assert self.sampler is not None
            output_bin_counts = self.sampler.penalties_state.output_bin_counts
        else:
            output_bin_counts = None
        post_update(
            input_batch.idx_mapping,
            self.req_states.num_computed_tokens.gpu,
            self.req_states.last_sampled_tokens,
            output_bin_counts,
            sampled_tokens,
            num_sampled,
            num_rejected,
            input_batch.query_start_loc,
            self.req_states.all_token_ids.gpu,
            self.req_states.total_len.gpu,
        )

        # Update the number of computed prefill tokens.
        idx_mapping_np = input_batch.idx_mapping_np
        computed_prefill = self.req_states.num_computed_prefill_tokens
        computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
        # Clamp so the counter never exceeds the request's prefill length.
        np.minimum(
            computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
        )

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        """Run the model forward pass for one scheduler step.

        Flow:
          1. Unless ``dummy_run``, update request state from
             ``scheduler_output``.
          2. Pick a batch descriptor / CUDA-graph mode and, with DP, sync the
             padding across ranks.
          3. Build real (or dummy) inputs and attention metadata.
          4. Run the model, either by FULL CUDA-graph replay or by calling the
             model inside ``set_forward_context``.
          5. Stash the results in ``self.execute_model_state`` for the
             following ``sample_tokens()`` or ``pool()`` call.

        Args:
            scheduler_output: Scheduler decisions for this step.
            intermediate_tensors: Tensors received from the previous PP stage.
                Required on non-first PP ranks.
            dummy_run: Run a dummy batch (for DP or memory profiling) without
                touching request state.
            skip_attn_for_dummy_run: For dummy runs, skip building block
                tables, slot mappings and attention metadata.

        Returns:
            The ``kv_connector.no_forward(...)`` result when there is nothing
            to run. On non-last PP ranks, the produced IntermediateTensors.
            On the last PP rank, None; the results are consumed later by
            ``sample_tokens()`` / ``pool()``.
        """
        if not dummy_run:
            # Update the request states.
            self.finish_requests(scheduler_output)
            self.free_states(scheduler_output)
            self.add_requests(scheduler_output)
            self.update_requests(scheduler_output)
            self.block_tables.apply_staged_writes()
            if scheduler_output.total_num_scheduled_tokens == 0:
                # No need to run the model.
                empty_output = self.kv_connector.no_forward(scheduler_output)
                return empty_output

        # Get batch descriptor and sync across DP ranks.
        num_reqs = len(scheduler_output.num_scheduled_tokens)
        num_toks = scheduler_output.total_num_scheduled_tokens
        max_query_len = max(scheduler_output.num_scheduled_tokens.values())
        uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)

        batch_desc = self.cudagraph_manager.dispatch(
            num_reqs, num_toks, uniform_tok_count
        )
        num_tokens_across_dp = None

        skip_compiled = False
        if self.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
            # Encoder-decoder models such as Whisper should run eager/non-compiled
            # when encoder inputs are scheduled, because this step updates
            # cross-attention cache with dynamic encoder outputs.
            # Override batch_desc to NONE.
            skip_compiled = True
            batch_desc = BatchExecutionDescriptor(
                cg_mode=CUDAGraphMode.NONE,
                num_tokens=num_toks,
                num_reqs=num_reqs,
            )

        if self.dp_size > 1:
            batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
                self.cudagraph_manager,
                batch_desc,
                num_toks,
                num_reqs,
                uniform_tok_count,
                self.dp_size,
                self.dp_rank,
            )

        if batch_desc.num_tokens == 0:
            # All DP ranks have zero tokens to run.
            empty_output = self.kv_connector.no_forward(scheduler_output)
            return empty_output

        if not dummy_run:
            # Common case.
            # Prepare all the inputs and copy to the input buffers.
            input_batch = self.prepare_inputs(scheduler_output, batch_desc)
            block_tables, slot_mappings = self.prepare_attn(input_batch)

            if self.lora_config:
                # Activate LoRA adapters.
                lora_inputs = self.lora_state.make_lora_inputs(
                    input_batch.req_ids,
                    input_batch.idx_mapping_np,
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*lora_inputs)
        else:
            # No actual tokens to run. A dummy run for DP or memory profiling.
            input_batch = InputBatch.make_dummy(
                batch_desc.num_reqs or num_reqs,
                batch_desc.num_tokens,
                self.input_buffers,
            )
            if not skip_attn_for_dummy_run:
                block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
            else:
                block_tables = None
                slot_mappings = None
            # FIXME(woosuk): Fix warmup for LoRA.

        # attn_metadata / slot_mappings_by_layer stay None only for dummy runs
        # that explicitly skip attention.
        attn_metadata = None
        slot_mappings_by_layer = None
        if not (dummy_run and skip_attn_for_dummy_run):
            assert slot_mappings is not None
            slot_mappings_by_layer = build_slot_mappings_by_layer(
                slot_mappings, self.kv_cache_config
            )
            assert block_tables is not None
            attn_metadata = self.model_state.prepare_attn(
                input_batch,
                batch_desc.cg_mode,
                block_tables,
                slot_mappings,
                self.attn_groups,
                self.kv_cache_config,
            )

        inputs_embeds = None
        if self.supports_mm_inputs and self.is_first_pp_rank:
            # Run MM encoder (if needed) and get multimodal embeddings.
            # Only first PP rank prepares multimodal embeddings.
            # NOTE(woosuk): We must call get_mm_embeddings even during dummy runs
            # to obtain inputs_embeds, because the compiled model expects this input.
            inputs_embeds = self.model_state.get_mm_embeddings(
                scheduler_output.scheduled_encoder_inputs,
                input_batch,
                self.req_states,
            )

        model_inputs = {
            "input_ids": input_batch.input_ids,
            "positions": input_batch.positions,
            "inputs_embeds": inputs_embeds,
            # NOTE: Values returned by `prepare_inputs` will override the default
            # values above.
            **self.model_state.prepare_inputs(input_batch, self.req_states),
        }
        if not self.is_first_pp_rank:
            # Update for non-first PP ranks.
            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = None

            # Prepare the intermediate tensors: copy the received tensors into
            # the runner-owned buffers (self.intermediate_tensors) and pass
            # views of those buffers to the model.
            assert intermediate_tensors is not None
            assert self.intermediate_tensors is not None
            n = input_batch.num_tokens_after_padding
            intermediate_tensors = IntermediateTensors(
                {
                    k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
                    for k, v in self.intermediate_tensors.tensors.items()
                },
                intermediate_tensors.kv_connector_output,
            )
            model_inputs["intermediate_tensors"] = intermediate_tensors

        # Run model.
        if batch_desc.cg_mode == CUDAGraphMode.FULL:
            # Use explicit cudagraph replay for FULL mode.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            self.kv_connector.pre_forward(scheduler_output)
            model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
        else:
            # For piecewise and eager mode, just call model().
            batch_descriptor = BatchDescriptor(
                num_tokens=input_batch.num_tokens_after_padding,
                has_lora=self.lora_config is not None,
            )

            with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=batch_desc.cg_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                batch_descriptor=batch_descriptor,
                slot_mapping=slot_mappings_by_layer,
                skip_compiled=skip_compiled,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                model_output = self.model(**model_inputs)

        # Unpack the model output according to this rank's PP position.
        if self.is_last_pp_rank:
            if self.use_aux_hidden_state_outputs:
                assert isinstance(model_output, tuple)
                hidden_states, aux_hidden_states = model_output
            else:
                assert isinstance(model_output, torch.Tensor)
                hidden_states = model_output
                aux_hidden_states = None
            output_intermediate_tensors = None
        else:
            assert isinstance(model_output, IntermediateTensors)
            hidden_states = None
            aux_hidden_states = None
            output_intermediate_tensors = model_output

        kv_connector_output = self.kv_connector.post_forward(scheduler_output)
        # Stash everything sample_tokens() / pool() need for this step.
        self.execute_model_state = ExecuteModelState(
            input_batch=input_batch,
            attn_metadata=attn_metadata,
            slot_mappings_by_layer=slot_mappings_by_layer,
            hidden_states=hidden_states,
            aux_hidden_states=aux_hidden_states,
            kv_connector_output=kv_connector_output,
            num_tokens_across_dp=num_tokens_across_dp,
        )

        if not self.is_last_pp_rank:
            # Non-last PP rank: return IntermediateTensors for sending.
            assert output_intermediate_tensors is not None
            output_intermediate_tensors.kv_connector_output = kv_connector_output
            return output_intermediate_tensors
        return None

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: GrammarOutput | None
    ) -> AsyncOutput | ModelRunnerOutput | None:
        """Sample tokens from the state stashed by the last ``execute_model``.

        The stashed ``self.execute_model_state`` is consumed (reset to None).

        On non-last PP ranks, this only receives the sampled tokens and counts
        broadcast by the last rank and updates local request state.

        On the last rank, it:
          - samples;
          - broadcasts the results to other PP ranks when PP is enabled;
          - computes prompt logprobs;
          - builds the output;
          - updates request state;
          - proposes draft tokens for the next step, when a speculator
            exists.

        Args:
            grammar_output: Structured-output bitmask to apply, or None.

        Returns:
            None if no state was stashed or on non-last PP ranks. Otherwise the
            AsyncOutput under async scheduling, or ``AsyncOutput.get_output()``
            when scheduling is synchronous.
        """
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None

        input_batch = self.execute_model_state.input_batch
        attn_metadata = self.execute_model_state.attn_metadata
        slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
        hidden_states = self.execute_model_state.hidden_states
        aux_hidden_states = self.execute_model_state.aux_hidden_states
        kv_connector_output = self.execute_model_state.kv_connector_output
        num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            # Non-last PP rank: hidden_states is None because this rank produced
            # IntermediateTensors instead of final hidden states. Receive the
            # sampled tokens broadcast from the last rank and update local state.
            sampled, num_sampled, num_rejected = pp_receive(
                input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
            )
            self.postprocess(input_batch, sampled, num_sampled, num_rejected)
            return None

        # Last rank: sample tokens
        sampler_output, num_sampled, num_rejected = self.sample(
            hidden_states, input_batch, grammar_output
        )

        if self.use_pp:
            # Broadcast to non-last PP ranks (handles spec decode multi-token).
            pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)

        assert self.prompt_logprobs_worker is not None
        prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
            self.model.compute_logits,
            hidden_states,
            input_batch,
            self.req_states.all_token_ids.gpu,
            self.req_states.num_computed_tokens.gpu,
            self.req_states.prompt_len.np,
            self.req_states.prefill_len.np,
            self.req_states.num_computed_prefill_tokens,
        )

        # Prepare the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            # NOTE(woosuk): req_id_to_index is unused in this model runner.
            # Only for compatibility with the existing model runner and scheduler.
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            sampled_token_ids=None,  # type: ignore
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
        if self.speculator is not None and self.speculator.supports_mm_inputs:
            # Get cached multimodal embeddings for draft forward.
            # NOTE: This is done here because postprocess updates
            # num_computed_prefill_tokens.
            prefill_lens = self.req_states.prefill_len.np[input_batch.idx_mapping_np]
            computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
                input_batch.idx_mapping_np
            ]
            mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
                input_batch.req_ids,
                input_batch.num_tokens,
                input_batch.num_scheduled_tokens,
                input_batch.query_start_loc_np,
                prefill_lens,
                computed_prefill_lens + 1,  # +1 to consider the skew in eagle
            )

        # Postprocess results and update request states.
        # NOTE: This is intentionally done after creating the AsyncOutput,
        # ensuring that `copy_event` is recorded before calling postprocess.
        # This sequencing may slightly reduce latency as async D2H copy does not
        # need to wait for the postprocess to finish.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
        )

        if self.speculator is not None:
            assert self.sampler is not None
            draft_tokens = self.speculator.propose(
                input_batch,
                attn_metadata,
                slot_mappings_by_layer,
                hidden_states,
                aux_hidden_states,
                num_sampled,
                num_rejected,
                self.req_states.last_sampled_tokens,
                self.req_states.next_prefill_tokens,
                self.sampler.sampling_states.temperature.gpu,
                self.sampler.sampling_states.seeds.gpu,
                num_tokens_across_dp=num_tokens_across_dp,
                mm_inputs=mm_inputs,
            )
            # Keep the drafts both in request state (read by the next
            # prepare_inputs) and in the handler (returned by
            # take_draft_token_ids).
            self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)

        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever the draft-token handler currently reports, or None."""
        handler = self.draft_tokens_handler
        return handler.get_draft_tokens()

    @torch.inference_mode()
    def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
        """Run the pooling head on the state stashed by the last ``execute_model``.

        The stashed ``self.execute_model_state`` is consumed (reset to None).
        Non-last PP ranks only update their request counters. The last rank
        pools the hidden states, updates counters, and builds the output.

        Returns:
            None if no state was stashed or on non-last PP ranks. Otherwise the
            AsyncPoolingOutput under async scheduling, or its ``get_output()``
            result when scheduling is synchronous.
        """
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None

        input_batch = self.execute_model_state.input_batch
        hidden_states = self.execute_model_state.hidden_states
        kv_connector_output = self.execute_model_state.kv_connector_output
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            self.postprocess_pool(input_batch)
            return None

        assert self.pooling_runner is not None
        pooler_output, is_valid = self.pooling_runner.pool(
            hidden_states, input_batch, self.req_states
        )
        self.postprocess_pool(input_batch)

        # Build the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncPoolingOutput(
            model_runner_output=model_runner_output,
            pooler_output=pooler_output,
            is_valid=is_valid,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )
        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

    def postprocess_pool(self, input_batch: InputBatch) -> None:
        """Advance per-request progress counters after a pooling step.

        The GPU-side ``num_computed_tokens`` update is delegated to
        ``post_update_pool``. The CPU-side ``num_computed_prefill_tokens``
        counter is increased by each request's scheduled token count and then
        clamped to the request's prefill length.
        """
        states = self.req_states

        # GPU side: update num_computed_tokens for the batch's requests.
        post_update_pool(
            input_batch.idx_mapping,
            states.num_computed_tokens.gpu,
            input_batch.query_start_loc,
        )

        # CPU side: accumulate, then cap at the prefill length (in place).
        prefill_done = states.num_computed_prefill_tokens
        prefill_done[input_batch.idx_mapping_np] += input_batch.num_scheduled_tokens
        np.minimum(prefill_done, states.prefill_len.np, out=prefill_done)


class ExecuteModelState(NamedTuple):
    """Per-step results that ``execute_model()`` hands to the next phase.

    ``execute_model()`` builds one after the forward pass and stores it on the
    runner. ``sample_tokens()`` or ``pool()`` reads it back and resets the
    runner's ``execute_model_state`` to None.
    """

    # The batch that was executed.
    input_batch: InputBatch
    # Attention metadata; None for dummy runs that skip attention.
    attn_metadata: dict[str, Any] | None
    # Per-layer slot mappings; None for dummy runs that skip attention.
    slot_mappings_by_layer: dict[str, torch.Tensor] | None
    # Final hidden states on the last PP rank; None on other PP ranks.
    hidden_states: torch.Tensor | None
    # Auxiliary hidden states when use_aux_hidden_state_outputs is set,
    # otherwise None.
    aux_hidden_states: list[torch.Tensor] | None
    # Result of kv_connector.post_forward() for this step.
    kv_connector_output: KVConnectorOutput | None
    # Per-DP-rank token counts from sync_cudagraph_and_dp_padding();
    # None when dp_size == 1.
    num_tokens_across_dp: torch.Tensor | None