model_runner.py 45.9 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
6
7
8
9
10
11
12
13
import gc
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
14
from vllm.distributed.parallel_state import (
15
    get_dcp_group,
16
17
18
    get_pp_group,
    prepare_communication_buffer_for_model,
)
19
from vllm.forward_context import BatchDescriptor, set_forward_context
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
from vllm.sequence import IntermediateTensors
24
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
28
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
29
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
30
from vllm.v1.worker.gpu.async_utils import AsyncOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
31
32
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
33
    build_slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
34
35
36
37
38
    get_kv_cache_spec,
    init_attn_backend,
    init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
39
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
40
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
Woosuk Kwon's avatar
Woosuk Kwon committed
41
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
42
from vllm.v1.worker.gpu.dp_utils import (
43
    get_cudagraph_and_dp_padding,
44
45
    make_num_tokens_across_dp,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
46
47
48
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
49
    combine_sampled_and_draft_tokens,
50
    expand_idx_mapping,
51
    get_num_sampled_and_rejected,
52
    post_update,
53
54
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
Woosuk Kwon's avatar
Woosuk Kwon committed
55
)
56
57
58
59
60
from vllm.v1.worker.gpu.kv_connector import (
    NO_OP_KV_CONNECTOR,
    KVConnector,
    get_kv_connector,
)
61
from vllm.v1.worker.gpu.lora_utils import LoraState
62
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
63
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
64
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
65
from vllm.v1.worker.gpu.sample.output import SamplerOutput
66
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
67
from vllm.v1.worker.gpu.sample.sampler import Sampler
68
from vllm.v1.worker.gpu.spec_decode import init_speculator
69
70
71
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
    set_eagle3_aux_hidden_state_layers,
)
72
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
73
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
74
from vllm.v1.worker.gpu.states import RequestState
75
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Woosuk Kwon's avatar
Woosuk Kwon committed
76
77
78
79
80
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

logger = init_logger(__name__)


81
class GPUModelRunner(LoRAModelRunnerMixin):
Woosuk Kwon's avatar
Woosuk Kwon committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache configuration and build the per-runner helper objects.

        No model is loaded and no KV cache is allocated here; see
        ``load_model`` and ``initialize_kv_cache``.

        Args:
            vllm_config: Top-level config; its sub-configs (model, cache,
                compilation, LoRA, load, parallel, scheduler, speculative,
                observability) are stored as attributes.
            device: Device on which buffers and CUDA streams are created.

        Raises:
            ValueError: If the speculative method is ``"eagle3"`` while
                pipeline parallelism is enabled (``pp_size > 1``).
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.device = device
        self.dtype = self.model_config.dtype
        # The KV cache uses the model dtype unless the cache config names one.
        self.kv_cache_dtype = self.dtype
        if self.cache_config.cache_dtype != "auto":
            # Quantized KV cache.
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype
            ]
        self.is_pooling_model = False

        self.vocab_size = self.model_config.get_vocab_size()
        self.max_model_len = self.model_config.max_model_len
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()

        # Multimodal
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config
        )
        # NOTE: `encoder_runner` only exists when multimodal inputs are
        # supported; every use must be guarded by `supports_mm_inputs`.
        if self.supports_mm_inputs:
            self.encoder_runner = EncoderRunner(
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                dtype=self.dtype,
                device=self.device,
            )
        # NOTE: likewise, `mrope_states` only exists when `uses_mrope` is set.
        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_states = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(self.device)
        self.output_copy_event = torch.cuda.Event()

        # Pipeline parallelism.
        self.pp_size = self.parallel_config.pipeline_parallel_size
        self.use_pp = self.pp_size > 1
        if self.use_pp:
            self.is_first_pp_rank = get_pp_group().is_first_rank
            self.is_last_pp_rank = get_pp_group().is_last_rank
        else:
            # Without PP this rank is both the first and the last stage.
            self.is_first_pp_rank = True
            self.is_last_pp_rank = True

        # Decode context parallelism.
        self.dcp_size = self.parallel_config.decode_context_parallel_size
        self.use_dcp = self.dcp_size > 1
        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size

        self.speculator = None
        self.use_aux_hidden_state_outputs = False
        if self.speculative_config is not None:
            self.do_spec_decode = True
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
            # Only the last PP rank creates a speculator; on other ranks
            # `self.speculator` stays None even though spec decode is on.
            if self.is_last_pp_rank:
                self.speculator = init_speculator(self.vllm_config, self.device)

            if self.speculative_config.method == "eagle3":
                # EAGLE3 may require auxiliary hidden states from target model outputs.
                self.use_aux_hidden_state_outputs = True
                if self.pp_size > 1:
                    raise ValueError("EAGLE3 with pipeline parallel is not supported.")
        else:
            self.do_spec_decode = False
            self.num_speculative_steps = 0

        # Draft tokens propagation - for spec-dec + struct outputs.
        self.draft_tokens_handler = DraftTokensHandler(self.device)

        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=self.device,
        )
        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            device=self.device,
        )
        self.sampler = Sampler(
            max_num_reqs=self.max_num_reqs,
            vocab_size=self.vocab_size,
            device=self.device,
            req_states=self.req_states,
            logprobs_mode=self.model_config.logprobs_mode,
            num_speculative_tokens=self.num_speculative_steps + 1,
        )
        self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            self.vllm_config,
            self.uses_mrope,
            self.use_aux_hidden_state_outputs,
            self.device,
        )
        # Structured outputs worker.
        self.structured_outputs_worker = StructuredOutputsWorker(
            max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
            vocab_size=self.vocab_size,
            device=self.device,
        )
        # LoRA-related workers.
        self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
        # KV Connector if configured.
        # Starts as the no-op connector; `initialize_kv_cache` replaces it.
        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

216
217
218
219
    def update_max_model_len(self, max_model_len: int) -> None:
        self.max_model_len = max_model_len
        self.req_states.max_model_len = max_model_len

220
221
    @staticmethod
    def get_supported_tasks() -> tuple[str]:
Woosuk Kwon's avatar
Woosuk Kwon committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
        return ("generate",)

    def load_model(self, *args, **kwargs) -> None:
        """Load the model (plus LoRA / speculator parts) and record its cost.

        Sets ``self.model`` and ``self.model_memory_usage`` and logs the
        memory consumed and wall-clock time of loading. ``*args`` and
        ``**kwargs`` are accepted but not used by this implementation.
        """
        time_before_load = time.perf_counter()
        # Everything inside the profiler context counts toward
        # `m.consumed_memory`, including LoRA wrapping and the speculator.
        with DeviceMemoryProfiler() as m:
            model_loader = get_model_loader(self.vllm_config.load_config)
            logger.info("Loading model from scratch...")

            self.model = model_loader.load_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.model_config,
            )
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model, self.vllm_config, self.device
                )

            if self.use_aux_hidden_state_outputs:
                assert self.speculative_config is not None
                set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
            if self.speculator is not None:
                self.speculator.load_model(self.model)
        time_after_load = time.perf_counter()

        self.model_memory_usage = m.consumed_memory
        logger.info(
            "Model loading took %s GiB and %.6f seconds",
            format_gib(m.consumed_memory),
            time_after_load - time_before_load,
        )

        prepare_communication_buffer_for_model(self.model)
        if self.do_spec_decode:
            # `self.speculator` may be None (non-last PP rank) or may not
            # expose a `model` attribute; getattr covers both cases.
            speculator_model = getattr(self.speculator, "model", None)
            if speculator_model is not None:
                prepare_communication_buffer_for_model(speculator_model)

Woosuk Kwon's avatar
Woosuk Kwon committed
259
260
261
    def get_model(self) -> nn.Module:
        return self.model

262
263
264
265
266
    @functools.cached_property
    def main_stream(self) -> torch.cuda.Stream:
        """CUDA stream that is current for ``self.device`` at first access.

        ``cached_property`` evaluates this once per instance, so later
        accesses reuse the stored stream instead of looking it up again.
        """
        stream = torch.cuda.current_stream(self.device)
        return stream

Woosuk Kwon's avatar
Woosuk Kwon committed
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
    def get_kv_cache_spec(self):
        """Return the KV cache spec derived from this runner's ``vllm_config``.

        Delegates to the module-level ``get_kv_cache_spec`` helper imported
        from ``attn_utils``.
        """
        config = self.vllm_config
        return get_kv_cache_spec(config)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Set up block tables, attention backends, KV caches and the connector.

        Args:
            kv_cache_config: KV cache layout. It is deep-copied first, so
                the caller's object is not the one stored on the runner.
        """
        kv_cache_config = deepcopy(kv_cache_config)
        self.kv_cache_config = kv_cache_config
        # One block size per KV cache group, in group order.
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]

        self.block_tables = BlockTables(
            block_sizes=block_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
            cp_size=self.dcp_size,
            cp_rank=self.dcp_rank,
            cp_interleave=self.cp_interleave,
        )

        self.attn_backends, self.attn_groups = init_attn_backend(
            self.kv_cache_config, self.vllm_config, self.device
        )
        check_attention_cp_compatibility(self.vllm_config)
        if self.speculator is not None:
            # HACK(woosuk)
            self.speculator.set_attn(
                self.kv_cache_config,
                self.attn_groups,
                self.block_tables,
            )

        # `init_kv_cache` receives this (initially empty) list along with the
        # static forward context; presumably it fills both in place.
        self.kv_caches: list[torch.Tensor] = []
        kv_caches_dict = init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
        )
        # Replaces the no-op connector installed by `__init__`.
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

Woosuk Kwon's avatar
Woosuk Kwon committed
311
312
313
314
315
    def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
        """Attach dummy attention metadata and slot mappings to ``input_batch``.

        Uses the dummy block tables / slot mappings provided by
        ``self.block_tables`` and writes the results to
        ``input_batch.attn_metadata`` and ``input_batch.slot_mappings``.

        Args:
            input_batch: Batch to mutate in place.
        """
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
        attn_metadata = build_attn_metadata(
            attn_groups=self.attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            # The model-wide maximum is used rather than the batch's own
            # maximum sequence length.
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
        )
        input_batch.attn_metadata = attn_metadata
        input_batch.slot_mappings = slot_mappings_by_layer
Woosuk Kwon's avatar
Woosuk Kwon committed
335
336
337

    @torch.inference_mode()
    def _dummy_run(
        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run one forward pass over a synthetic batch of ``num_tokens`` tokens.

        Args:
            num_tokens: Total tokens to schedule. Must be positive, since the
                derived request count is used as a divisor.
            skip_attn: Forwarded to ``execute_model`` as
                ``skip_attn_for_dummy_run``.
            *args, **kwargs: Accepted but not used by this implementation.

        Returns:
            ``(hidden_states, sample_hidden_states)`` on the last PP rank,
            where the second item is ``hidden_states`` indexed by the batch's
            ``logits_indices``; ``(None, None)`` on every other PP rank.
        """
        # Create a dummy scheduler output.
        num_reqs = min(num_tokens, self.max_num_reqs)
        # Split tokens evenly; the last request absorbs the remainder.
        num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
        num_tokens_per_request[-1] += num_tokens % num_reqs
        assert sum(num_tokens_per_request) == num_tokens
        num_scheduled_tokens = {
            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
        }
        dummy_scheduler_output = SchedulerOutput.make_empty()
        dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
        dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens

        # Disable any use of KVConnector for dummy runs.
        self.kv_connector.set_disabled(True)
        try:
            # For non-first PP ranks, create dummy intermediate_tensors.
            intermediate_tensors = None
            if not self.is_first_pp_rank:
                intermediate_tensors = self.model.make_empty_intermediate_tensors(
                    batch_size=num_tokens,
                    dtype=self.model_config.dtype,
                    device=self.device,
                )

            # Execute the model.
            self.execute_model(
                dummy_scheduler_output,
                intermediate_tensors=intermediate_tensors,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
            )
        finally:
            # Re-enable the connector even if the dummy pass raised, so a
            # failed warmup/profile run cannot leave it disabled for real
            # requests.
            self.kv_connector.set_disabled(False)

        # Non-last PP ranks don't produce output for sampling.
        if not self.is_last_pp_rank:
            return None, None

        assert self.execute_model_state is not None
        hidden_states, _, input_batch, _ = self.execute_model_state
        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        return hidden_states, sample_hidden_states

    @torch.inference_mode()
384
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
Woosuk Kwon's avatar
Woosuk Kwon committed
385
386
        num_reqs = hidden_states.shape[0]
        logits = self.model.compute_logits(hidden_states)
387
388
389
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
390
391
392
393
        dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
        expanded_local_pos = torch.zeros(
            num_reqs, dtype=torch.int32, device=self.device
        )
394
395
396
        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
        # top_k, top_p, and logprobs, using less GPU memory than what is possible
        # during actual execution.
397
398
399
400
401
402
403
404
405
        self.sampler(
            logits,
            idx_mapping,
            idx_mapping_np,
            idx_mapping_np,
            pos,
            dummy_input_ids,
            expanded_local_pos,
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
406
407
408
409

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Exercise model, sampler and speculator once at maximum token count.

        Runs ``_dummy_run`` with ``self.max_num_tokens`` (attention skipped);
        on the last PP rank it also runs the dummy sampler and, if present,
        the speculator. It then synchronizes CUDA and drops the outputs.
        """
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens, skip_attn=True
        )

        # Only run sampler on last PP rank (non-last ranks return None).
        if self.is_last_pp_rank:
            assert sample_hidden_states is not None
            self._dummy_sampler_run(sample_hidden_states)

            if self.speculator is not None:
                num_tokens_across_dp = make_num_tokens_across_dp(
                    self.parallel_config.data_parallel_size, self.max_num_tokens
                )
                self.speculator.run_model(
                    self.max_num_tokens,
                    attn_metadata=None,
                    slot_mappings=None,
                    num_tokens_across_dp=num_tokens_across_dp,
                )

        # Wait for the queued GPU work before dropping the output references.
        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
434
435
        if self.supports_mm_inputs:
            self.encoder_runner.reset_mm_cache()
436
437

    def reset_encoder_cache(self) -> None:
438
439
        if self.supports_mm_inputs:
            self.encoder_runner.reset_encoder_cache()
Woosuk Kwon's avatar
Woosuk Kwon committed
440
441
442
443
444
445
446
447
448
449
450
451
452
453

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        # SP is not supported yet.
        return num_scheduled_tokens

    @torch.inference_mode()
    def capture_model(self) -> int:
        """Capture CUDA graphs for the model (and speculator, if any).

        Returns:
            The drop in free GPU memory, in bytes, measured with
            ``torch.cuda.mem_get_info()`` before and after capture; ``0``
            when capture is skipped (not needed, or pipeline parallelism
            is enabled).
        """
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        # TODO (zhanqiu): support CUDA graph for PP.
        if self.use_pp:
            logger.warning_once(
                "Skipping CUDA graph capture because pipeline parallel is "
                "enabled. Pipeline parallel is currently eager-only.",
            )
            return 0

        start_time = time.perf_counter()
        # Free what can be freed first so the memory delta below mostly
        # reflects the capture itself.
        gc.collect()
        torch.cuda.empty_cache()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            # Optional buffers only exist when the matching feature is on.
            mrope_positions = None
            if self.uses_mrope:
                mrope_positions = self.mrope_states.mrope_positions
            inputs_embeds = None
            if self.supports_mm_inputs:
                inputs_embeds = self.encoder_runner.inputs_embeds
            self.cudagraph_manager.capture(
                model=self.model,
                input_buffers=self.input_buffers,
                mrope_positions=mrope_positions,
                inputs_embeds=inputs_embeds,
                block_tables=self.block_tables,
                attn_groups=self.attn_groups,
                kv_cache_config=self.kv_cache_config,
                has_lora=self.lora_config is not None,
            )
            if self.speculator is not None:
                self.speculator.capture_model()

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        # For FlashInfer, we would like to execute a dummy prefill run
        # to trigger JIT compilation.
        if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
            self._dummy_run(self.max_num_tokens, skip_attn=False)
            torch.cuda.synchronize()

506
    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
507
        finished_req_ids = scheduler_output.finished_req_ids
508
509
510
        preempted_req_ids = scheduler_output.preempted_req_ids
        if preempted_req_ids:
            finished_req_ids = finished_req_ids.union(preempted_req_ids)
511
        for req_id in finished_req_ids:
Woosuk Kwon's avatar
Woosuk Kwon committed
512
            self.req_states.remove_request(req_id)
513
514
            if self.supports_mm_inputs:
                self.encoder_runner.remove_request(req_id)
515
            self.prompt_logprobs_worker.remove_request(req_id)
516
            self.lora_state.remove_request(req_id)
517

518
    def free_states(self, scheduler_output: SchedulerOutput) -> None:
519
520
521
        if self.supports_mm_inputs:
            for mm_hash in scheduler_output.free_encoder_mm_hashes:
                self.encoder_runner.free_encoder_cache(mm_hash)
Woosuk Kwon's avatar
Woosuk Kwon committed
522

523
    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
Woosuk Kwon's avatar
Woosuk Kwon committed
524
        for new_req_data in scheduler_output.scheduled_new_reqs:
525
526
527
            assert new_req_data.prompt_token_ids is not None
            assert new_req_data.prefill_token_ids is not None
            assert new_req_data.sampling_params is not None
Woosuk Kwon's avatar
Woosuk Kwon committed
528
            req_id = new_req_data.req_id
529
            prompt_len = len(new_req_data.prompt_token_ids)
Woosuk Kwon's avatar
Woosuk Kwon committed
530
531
            self.req_states.add_request(
                req_id=req_id,
532
                prompt_len=prompt_len,
533
                all_token_ids=new_req_data.prefill_token_ids,
Woosuk Kwon's avatar
Woosuk Kwon committed
534
535
536
                num_computed_tokens=new_req_data.num_computed_tokens,
            )
            req_index = self.req_states.req_id_to_index[req_id]
537

538
539
540
            if self.supports_mm_inputs:
                self.encoder_runner.add_request(req_id, new_req_data.mm_features)

541
542
543
544
545
546
            # Pre-compute M-RoPE positions for prefill.
            if self.uses_mrope:
                self.mrope_states.init_prefill_mrope_positions(
                    req_index,
                    self.model,  # type: ignore
                    new_req_data.prefill_token_ids,
547
                    mm_features=new_req_data.mm_features,
548
549
                )

550
551
552
            self.block_tables.append_block_ids(
                req_index, new_req_data.block_ids, overwrite=True
            )
553
554
555
            self.sampler.add_request(
                req_index, prompt_len, new_req_data.sampling_params
            )
556
557
558
            self.prompt_logprobs_worker.add_request(
                req_id, req_index, new_req_data.sampling_params
            )
559
            self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
Woosuk Kwon's avatar
Woosuk Kwon committed
560

561
562
        if scheduler_output.scheduled_new_reqs:
            self.req_states.apply_staged_writes()
563
            self.sampler.apply_staged_writes()
564
565
566
567
            if self.uses_mrope:
                self.mrope_states.apply_staged_writes()

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
Woosuk Kwon's avatar
Woosuk Kwon committed
568
        # Add new blocks for the existing requests.
569
570
        reqs = scheduler_output.scheduled_cached_reqs
        for req_new_block_ids, req_id in zip(reqs.new_block_ids, reqs.req_ids):
Woosuk Kwon's avatar
Woosuk Kwon committed
571
            if req_new_block_ids is not None:
572
                req_index = self.req_states.req_id_to_index[req_id]
573
574
575
                self.block_tables.append_block_ids(
                    req_index, req_new_block_ids, overwrite=False
                )
Woosuk Kwon's avatar
Woosuk Kwon committed
576
577

    def prepare_inputs(
578
        self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
Woosuk Kwon's avatar
Woosuk Kwon committed
579
580
581
    ) -> InputBatch:
        num_tokens = scheduler_output.total_num_scheduled_tokens
        assert num_tokens > 0
582
583
        num_tokens_per_req = scheduler_output.num_scheduled_tokens
        num_reqs = len(num_tokens_per_req)
Woosuk Kwon's avatar
Woosuk Kwon committed
584
585
586

        # Decode first, then prefill.
        # batch_idx -> req_id
587
588
589
        req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get)  # type: ignore[arg-type]
        numtoks_iter = map(num_tokens_per_req.get, req_ids)
        num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
Woosuk Kwon's avatar
Woosuk Kwon committed
590

591
592
        idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
        idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
593
        idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
Woosuk Kwon's avatar
Woosuk Kwon committed
594

595
        # Get the number of draft tokens for each request.
596
597
        draft_tokens = scheduler_output.scheduled_spec_decode_tokens
        if not draft_tokens:
598
599
600
            # No draft token scheduled (common case).
            total_num_draft_tokens = 0
            total_num_logits = num_reqs
601
            cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
602
603
604
            cu_num_logits = torch.arange(
                num_reqs + 1, device=self.device, dtype=torch.int32
            )
605
            expanded_idx_mapping = idx_mapping
606
607
608
            expanded_local_pos = torch.zeros(
                num_reqs, dtype=torch.int32, device=self.device
            )
609
610
        else:
            num_draft_tokens = np.array(
611
                [len(draft_tokens.get(req_id, ())) for req_id in req_ids],
612
613
614
615
616
                dtype=np.int32,
            )
            total_num_draft_tokens = int(num_draft_tokens.sum())
            total_num_logits = num_reqs + total_num_draft_tokens

617
618
619
620
            num_logits = num_draft_tokens + 1
            cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
            cu_num_logits_np[0] = 0
            np.cumsum(num_logits, out=cu_num_logits_np[1:])
621
            cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
622

623
            max_expand_len = self.num_speculative_steps + 1
624
            expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
625
                idx_mapping, total_num_logits, cu_num_logits, max_expand_len
626
627
            )

Woosuk Kwon's avatar
Woosuk Kwon committed
628
629
630
        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

631
        # Get query_start_loc.
632
633
634
        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
635
636
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
637
        query_start_loc_np[num_reqs + 1 :] = num_tokens
638
639
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)

640
641
642
        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
        query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
643
        max_query_len = num_scheduled_tokens.max().item()
644

645
646
647
648
649
650
651
652
653
654
655
        # Get prefill tokens if any.
        if self.req_states.any_prefills(idx_mapping_np):
            prepare_prefill_inputs(
                self.input_buffers.input_ids,
                self.req_states.next_prefill_tokens,
                idx_mapping,
                query_start_loc,
                self.req_states.all_token_ids.gpu,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )
Woosuk Kwon's avatar
Woosuk Kwon committed
656

657
658
659
        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
            idx_mapping,
660
661
            query_start_loc,
            self.req_states.num_computed_tokens.gpu,
662
663
664
665
666
            self.input_buffers.positions,
            self.input_buffers.seq_lens,
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

667
668
        if self.use_dcp:
            # Prepare dcp local seq_lens.
669
670
            prepare_dcp_local_seq_lens(
                self.input_buffers.dcp_local_seq_lens,
671
                self.input_buffers.seq_lens,
672
                num_reqs,
673
674
675
                self.dcp_size,
                self.dcp_rank,
                self.cp_interleave,
676
            )
677
        dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
678

679
680
681
682
683
684
685
686
687
        # Prepare M-RoPE positions.
        if self.uses_mrope:
            self.mrope_states.prepare_mrope_positions(
                idx_mapping,
                query_start_loc,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )

688
        # Some input token ids are directly read from the last sampled tokens
689
690
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
691
            self.input_buffers.input_ids,
Woosuk Kwon's avatar
Woosuk Kwon committed
692
693
            idx_mapping,
            self.req_states.last_sampled_tokens,
694
            query_start_loc,
695
696
            seq_lens,
            self.req_states.prefill_len.gpu,
697
698
699
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
Woosuk Kwon's avatar
Woosuk Kwon committed
700
701
702
703
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
704
705
706
            idx_mapping,
            query_start_loc,
            self.input_buffers.positions[:num_tokens],
Woosuk Kwon's avatar
Woosuk Kwon committed
707
        )
708
709
710
711
        # Layer name -> slot mapping.
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
712
713
714

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
715
            attn_groups=self.attn_groups,
Woosuk Kwon's avatar
Woosuk Kwon committed
716
717
            num_reqs=num_reqs,
            num_tokens=num_tokens,
718
            query_start_loc_gpu=query_start_loc,
719
            query_start_loc_cpu=query_start_loc_cpu,
720
            max_query_len=max_query_len,
Woosuk Kwon's avatar
Woosuk Kwon committed
721
            seq_lens=self.input_buffers.seq_lens,
722
            max_seq_len=self.max_model_len,
Woosuk Kwon's avatar
Woosuk Kwon committed
723
724
725
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
726
            dcp_local_seq_lens=dcp_local_seq_lens,
Woosuk Kwon's avatar
Woosuk Kwon committed
727
728
        )

729
        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
730
        positions = self.input_buffers.positions[:num_tokens_after_padding]
731
732
        mrope_positions = None
        if self.uses_mrope:
733
734
            mrope_positions = self.mrope_states.mrope_positions
            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
Woosuk Kwon's avatar
Woosuk Kwon committed
735
736
737
738
739
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
740
            expanded_idx_mapping=expanded_idx_mapping,
741
            expanded_local_pos=expanded_local_pos,
Woosuk Kwon's avatar
Woosuk Kwon committed
742
743
744
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
745
            num_draft_tokens=total_num_draft_tokens,
746
            query_start_loc=query_start_loc,
Woosuk Kwon's avatar
Woosuk Kwon committed
747
            query_start_loc_np=query_start_loc_np,
748
            seq_lens=seq_lens,
Woosuk Kwon's avatar
Woosuk Kwon committed
749
750
            input_ids=input_ids,
            positions=positions,
751
            mrope_positions=mrope_positions,
752
            inputs_embeds=None,
Woosuk Kwon's avatar
Woosuk Kwon committed
753
            attn_metadata=attn_metadata,
754
            slot_mappings=slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
755
            logits_indices=logits_indices,
756
            cu_num_logits=cu_num_logits,
757
            cu_num_logits_np=cu_num_logits_np,
758
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
Woosuk Kwon's avatar
Woosuk Kwon committed
759
760
        )

761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
    @torch.inference_mode()
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Run the multimodal encoder and gather embeddings for this batch.

        Args:
            scheduled_encoder_inputs: Forwarded unchanged to
                `encoder_runner.prepare_mm_inputs`. Presumably keyed by
                request ID -- verify against the scheduler output.
            input_batch: The batch whose request IDs, token counts and
                `query_start_loc_np` select what is gathered.

        Returns:
            The `(mm_embeds, is_mm_embed)` pair produced by
            `encoder_runner.gather_mm_embeddings`.
        """
        encoder = self.encoder_runner
        hashes, encoder_kwargs = encoder.prepare_mm_inputs(scheduled_encoder_inputs)
        # The return value of the encoder pass is not consumed here; the
        # embeddings are fetched through gather_mm_embeddings() below.
        encoder.execute_mm_encoder(self.model, hashes, encoder_kwargs)

        # Per-request prefill bookkeeping, restricted to the batch's state slots.
        slots = input_batch.idx_mapping_np
        prefill_lens = self.req_states.prefill_len.np[slots]
        computed_prefill_lens = self.req_states.num_computed_prefill_tokens[slots]

        embeds, embed_mask = encoder.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            prefill_lens,
            computed_prefill_lens,
        )
        return embeds, embed_mask

Woosuk Kwon's avatar
Woosuk Kwon committed
781
782
783
784
785
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
        """Compute logits for the sampled positions and draw tokens.

        Args:
            hidden_states: Model output, indexed here by
                `input_batch.logits_indices`.
            input_batch: The batch that produced `hidden_states`.
            grammar_output: When not None, its bitmask is applied to the
                logits in-place before sampling.

        Returns:
            `(sampler_output, num_sampled, num_rejected)`. When the batch
            carries draft tokens, `sampler_output.sampled_token_ids` is
            overwritten with the result of rejection sampling.
        """
        picked = input_batch.logits_indices
        picked_hidden_states = hidden_states[picked]
        picked_positions = input_batch.positions[picked]
        picked_input_ids = input_batch.input_ids[picked]

        logits = self.model.compute_logits(picked_hidden_states)
        if grammar_output is not None:
            # Apply grammar bitmask to the logits in-place.
            self.structured_outputs_worker.apply_grammar_bitmask(
                logits,
                input_batch,
                grammar_output.structured_output_request_ids,
                grammar_output.grammar_bitmask,
            )

        # Sample tokens and compute logprobs (if needed).
        sampler_output = self.sampler(
            logits,
            input_batch.expanded_idx_mapping,
            input_batch.idx_mapping_np,
            input_batch.cu_num_logits_np,
            picked_positions,
            picked_input_ids,
            input_batch.expanded_local_pos,
        )

        if input_batch.num_draft_tokens != 0:
            # Rejection sampling for spec decoding.
            accepted_tokens, sampled_counts = rejection_sample(
                sampler_output.sampled_token_ids,
                picked_input_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            sampler_output.sampled_token_ids = accepted_tokens
        else:
            # No draft tokens (common case): one sampled token per request.
            sampled_counts = torch.ones(
                input_batch.num_reqs, dtype=torch.int32, device=self.device
            )

        # Get the number of sampled and rejected tokens.
        # For chunked prefills, num_sampled and num_rejected are both 0.
        num_sampled, num_rejected = get_num_sampled_and_rejected(
            sampled_counts,
            input_batch.seq_lens,
            input_batch.cu_num_logits,
            input_batch.idx_mapping,
            self.req_states.prefill_len.gpu,
        )
        return sampler_output, num_sampled, num_rejected
Woosuk Kwon's avatar
Woosuk Kwon committed
836
837
838
839

    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Fold the sampling results of one step back into the request states.

        Args:
            input_batch: The batch that was just executed.
            sampled_tokens: Sampled token IDs, forwarded to `post_update`.
            num_sampled: Forwarded to `post_update`.
            num_rejected: Forwarded to `post_update`.
        """
        req_states = self.req_states

        # Update the number of computed tokens.
        post_update(
            input_batch.idx_mapping,
            req_states.num_computed_tokens.gpu,
            req_states.last_sampled_tokens,
            self.sampler.penalties_state.output_bin_counts,
            sampled_tokens,
            num_sampled,
            num_rejected,
            input_batch.query_start_loc,
            req_states.all_token_ids.gpu,
            req_states.total_len.gpu,
        )

        # Update the number of computed prefill tokens: advance the batch's
        # slots by what was scheduled, then clamp the counters so they never
        # exceed the prefill length. The clamp runs over the whole array,
        # not only over the slots touched by this batch.
        prefill_progress = req_states.num_computed_prefill_tokens
        prefill_progress[input_batch.idx_mapping_np] += input_batch.num_scheduled_tokens
        np.minimum(prefill_progress, req_states.prefill_len.np, out=prefill_progress)

866
867
868
869
870
871
872
    @torch.inference_mode()
    def propose_draft(
        self,
        input_batch: InputBatch,
        last_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> torch.Tensor:
        """Ask the speculator for draft tokens for the given batch.

        Requires `self.speculator` to be set (asserted). All arguments are
        forwarded to `speculator.propose` together with the last sampled
        tokens, the next prefill tokens, and the sampler's temperature and
        seed tensors.

        Returns:
            Whatever `speculator.propose` returns (the draft tokens).
        """
        speculator = self.speculator
        assert speculator is not None

        req_states = self.req_states
        sampling_states = self.sampler.sampling_states
        return speculator.propose(
            input_batch,
            last_hidden_states,
            aux_hidden_states,
            num_sampled,
            num_rejected,
            req_states.last_sampled_tokens,
            req_states.next_prefill_tokens,
            sampling_states.temperature.gpu,
            sampling_states.seeds.gpu,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
889
890
891
892
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        """Run one forward pass and stash its results for `sample_tokens`.

        Args:
            scheduler_output: The scheduling decision for this step.
            intermediate_tensors: Must be None on the first PP rank and
                non-None on the other ranks when the model is called
                directly (asserted in the piecewise/eager path).
            dummy_run: When True, request states are not updated and a dummy
                `InputBatch` is built instead of calling `prepare_inputs`.
            skip_attn_for_dummy_run: When True, a dummy run does not prepare
                dummy attention metadata.

        Returns:
            - The result of `kv_connector.no_forward` when there is nothing
              to run (no scheduled tokens, or zero tokens after DP padding).
            - The model's `IntermediateTensors` on a non-last PP rank.
            - None on the last PP rank; the hidden states are kept in
              `self.execute_model_state`.
        """
        if not dummy_run:
            # Update the request states.
            self.finish_requests(scheduler_output)
            self.free_states(scheduler_output)
            self.add_requests(scheduler_output)
            self.update_requests(scheduler_output)
            self.block_tables.apply_staged_writes()
            if scheduler_output.total_num_scheduled_tokens == 0:
                # No need to run the model.
                return self.kv_connector.no_forward(scheduler_output)

        # Get local cudagraph mode and size.
        tokens_per_req = scheduler_output.num_scheduled_tokens
        total_num_tokens = scheduler_output.total_num_scheduled_tokens
        local_mode, local_size = self.cudagraph_manager.get_cudagraph_runtime_mode(
            num_reqs=len(tokens_per_req),
            num_tokens=total_num_tokens,
            max_query_len=max(tokens_per_req.values()),
        )

        # DP sync: num_tokens + cudagraph_size + cudagraph_mode
        padded_num_tokens, num_tokens_across_dp, synced_mode_value = (
            get_cudagraph_and_dp_padding(
                total_num_tokens,
                local_size,
                local_mode.value,
                self.parallel_config.data_parallel_size,
                self.parallel_config.data_parallel_rank,
            )
        )
        runtime_mode = CUDAGraphMode(synced_mode_value)
        if padded_num_tokens == 0:
            # All DP ranks have zero tokens to run.
            return self.kv_connector.no_forward(scheduler_output)

        if dummy_run:
            # No actual tokens to run. A dummy run for DP or memory profiling.
            input_batch = InputBatch.make_dummy(
                num_reqs=min(padded_num_tokens, self.max_num_reqs),
                num_tokens=padded_num_tokens,
                input_buffers=self.input_buffers,
                device=self.device,
            )
            if self.uses_mrope:
                input_batch.mrope_positions = self.mrope_states.mrope_positions[
                    :, :padded_num_tokens
                ]
            if not skip_attn_for_dummy_run:
                self.prepare_dummy_attn_metadata(input_batch)
            # FIXME(woosuk): Fix warmup for LoRA.
        else:
            # Common case.
            # Prepare all the inputs and copy to the input buffers.
            input_batch = self.prepare_inputs(scheduler_output, padded_num_tokens)
            if self.lora_config:
                # Activate LoRA adapters.
                active_lora_args = self.lora_state.make_lora_inputs(
                    input_batch.req_ids,
                    input_batch.idx_mapping_np,
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*active_lora_args)

            # Only first PP rank prepares multimodal embeddings.
            if self.supports_mm_inputs and self.is_first_pp_rank:
                mm_embeds, is_mm_embed = self.get_mm_embeddings(
                    scheduler_output.scheduled_encoder_inputs, input_batch
                )
                merged_embeds = self.encoder_runner.get_inputs_embeds(
                    self.model, input_batch.input_ids, mm_embeds, is_mm_embed
                )
                input_batch.inputs_embeds = merged_embeds[
                    : input_batch.num_tokens_after_padding
                ]

        # Run model.
        if runtime_mode == CUDAGraphMode.FULL:
            # Use explicit cudagraph replay for FULL mode.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            self.kv_connector.pre_forward(scheduler_output)
            model_output = self.cudagraph_manager.run_fullgraph(
                input_batch.num_tokens_after_padding
            )
        else:
            # For piecewise and eager mode, just call model().
            if self.uses_mrope:
                assert input_batch.mrope_positions is not None
                positions = input_batch.mrope_positions
            else:
                positions = input_batch.positions

            # Only the first PP rank feeds token IDs / embeddings; later
            # ranks consume the intermediate tensors instead.
            if self.is_first_pp_rank:
                assert intermediate_tensors is None
                input_ids = input_batch.input_ids
                inputs_embeds = input_batch.inputs_embeds
            else:
                assert intermediate_tensors is not None
                input_ids = None
                inputs_embeds = None

            batch_descriptor = BatchDescriptor(
                num_tokens=input_batch.num_tokens_after_padding,
                has_lora=self.lora_config is not None,
            )

            with set_forward_context(
                input_batch.attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=runtime_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                batch_descriptor=batch_descriptor,
                slot_mapping=input_batch.slot_mappings,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    inputs_embeds=inputs_embeds,
                    intermediate_tensors=intermediate_tensors,
                )

        # Both run paths yield the same output layout: a pair when auxiliary
        # hidden states are requested, a single value otherwise.
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states, aux_hidden_states = model_output, None

        kv_connector_output = self.kv_connector.post_forward(scheduler_output)

        if not self.is_last_pp_rank:
            # Non-last PP rank: return IntermediateTensors for sending.
            assert isinstance(hidden_states, IntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
            self.execute_model_state = (None, None, input_batch, kv_connector_output)
            return hidden_states

        # Last rank (or no PP): hidden_states is a tensor for sampling.
        assert isinstance(hidden_states, torch.Tensor)
        self.execute_model_state = (
            hidden_states,
            aux_hidden_states,
            input_batch,
            kv_connector_output,
        )  # type: ignore
        return None

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: GrammarOutput | None
    ) -> AsyncOutput | ModelRunnerOutput | None:
        """Sample tokens from the state saved by the preceding `execute_model`.

        Consumes and clears `self.execute_model_state`, which must be set
        (asserted).

        Args:
            grammar_output: Forwarded to `sample`; may be None.

        Returns:
            - None on a non-last PP rank, after local state has been updated
              from the values obtained through `pp_receive`.
            - The `AsyncOutput` itself when `self.use_async_scheduling` is
              set, otherwise the result of its `get_output()`.
        """
        saved_state = self.execute_model_state
        assert saved_state is not None
        hidden_states, aux_hidden_states, input_batch, kv_connector_output = saved_state
        self.execute_model_state = None  # type: ignore

        if not self.is_last_pp_rank:
            # Non-last PP rank: hidden_states is None because this rank produced
            # IntermediateTensors instead of final hidden states. Receive the
            # sampled tokens broadcast from the last rank and update local state.
            sampled, num_sampled, num_rejected = pp_receive(
                input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
            )
            self.postprocess(input_batch, sampled, num_sampled, num_rejected)
            return None

        # Last rank: sample tokens
        sampler_output, num_sampled, num_rejected = self.sample(
            hidden_states, input_batch, grammar_output
        )

        if self.use_pp:
            # Broadcast to non-last PP ranks (handles spec decode multi-token).
            pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)

        req_states = self.req_states
        prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
            self.model.compute_logits,
            hidden_states,
            input_batch,
            req_states.all_token_ids.gpu,
            req_states.num_computed_tokens.gpu,
            req_states.prompt_len.np,
            req_states.prefill_len.np,
            req_states.num_computed_prefill_tokens,
        )

        # Prepare the model runner output.
        req_ids = input_batch.req_ids
        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            # NOTE(woosuk): req_id_to_index is unused in this model runner.
            # Only for compatibility with the existing model runner and scheduler.
            req_id_to_index={rid: pos for pos, rid in enumerate(req_ids)},
            sampled_token_ids=None,  # type: ignore
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        # Postprocess results and update request states.
        # NOTE: This is intentionally done after creating the AsyncOutput,
        # ensuring that `copy_event` is recorded before calling postprocess.
        # This sequencing may slightly reduce latency as async D2H copy does not
        # need to wait for the postprocess to finish.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
        )

        if self.speculator is not None:
            # Propose draft tokens and store them in the request states and
            # in the draft-token handler.
            draft_tokens = self.propose_draft(
                input_batch,
                hidden_states,
                aux_hidden_states,
                num_sampled,
                num_rejected,
            )
            req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)

        return async_output if self.use_async_scheduling else async_output.get_output()
1134
1135
1136

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft tokens currently held by the draft-token handler.

        Delegates to `draft_tokens_handler.get_draft_tokens()`; per the return
        annotation the result may be None.
        """
        handler = self.draft_tokens_handler
        return handler.get_draft_tokens()