"vscode:/vscode.git/clone" did not exist on "fb0acb6c72874e98617cabee4ff4851569374fc9"
model_runner.py 45.2 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
6
7
8
9
10
11
12
13
import gc
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
14
from vllm.distributed.parallel_state import (
15
    get_dcp_group,
16
17
18
    get_pp_group,
    prepare_communication_buffer_for_model,
)
19
from vllm.forward_context import BatchDescriptor, set_forward_context
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
from vllm.sequence import IntermediateTensors
24
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
28
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
29
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
30
from vllm.v1.worker.gpu.async_utils import AsyncOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
31
32
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
33
    build_slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
34
35
36
37
38
    get_kv_cache_spec,
    init_attn_backend,
    init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
39
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
40
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
Woosuk Kwon's avatar
Woosuk Kwon committed
41
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
42
from vllm.v1.worker.gpu.dp_utils import (
43
    get_cudagraph_and_dp_padding,
44
45
    make_num_tokens_across_dp,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
46
47
48
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
49
    combine_sampled_and_draft_tokens,
50
    expand_idx_mapping,
51
    get_num_sampled_and_rejected,
52
    post_update,
53
54
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
Woosuk Kwon's avatar
Woosuk Kwon committed
55
)
56
57
58
59
60
from vllm.v1.worker.gpu.kv_connector import (
    NO_OP_KV_CONNECTOR,
    KVConnector,
    get_kv_connector,
)
61
from vllm.v1.worker.gpu.lora_utils import LoraState
62
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
63
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
64
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
65
from vllm.v1.worker.gpu.sample.output import SamplerOutput
66
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
67
from vllm.v1.worker.gpu.sample.sampler import Sampler
68
from vllm.v1.worker.gpu.spec_decode import init_speculator
69
70
71
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
    set_eagle3_aux_hidden_state_layers,
)
72
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
73
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
74
from vllm.v1.worker.gpu.states import RequestState
75
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Woosuk Kwon's avatar
Woosuk Kwon committed
76
77
78
79
80
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

# Module-level logger used by GPUModelRunner for load/capture/warning messages.
logger = init_logger(__name__)


81
class GPUModelRunner(LoRAModelRunnerMixin):
Woosuk Kwon's avatar
Woosuk Kwon committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up per-worker state for the GPU model runner.

        Caches the sub-configs of ``vllm_config``, resolves dtypes and size
        limits, and constructs the helper objects used at runtime: the encoder
        runner and M-RoPE state (when applicable), request state, input
        buffers, sampler, prompt-logprobs worker, CUDA graph manager,
        structured-outputs worker and LoRA state. The KV cache is not
        allocated here; that happens in ``initialize_kv_cache``.

        Args:
            vllm_config: Complete vLLM configuration for this worker.
            device: CUDA device the runner operates on.

        Raises:
            ValueError: If EAGLE3 speculative decoding is combined with
                pipeline parallelism.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.device = device
        self.dtype = self.model_config.dtype
        self.kv_cache_dtype = self.dtype
        if self.cache_config.cache_dtype != "auto":
            # Quantized KV cache.
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype
            ]
        self.is_pooling_model = False

        self.vocab_size = self.model_config.get_vocab_size()
        self.max_model_len = self.model_config.max_model_len
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()

        # Multimodal.
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config
        )
        if self.supports_mm_inputs:
            self.encoder_runner = EncoderRunner(
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                dtype=self.dtype,
                device=self.device,
            )
        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_states = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(self.device)
        self.output_copy_event = torch.cuda.Event()

        # Pipeline parallelism.
        self.pp_size = self.parallel_config.pipeline_parallel_size
        self.use_pp = self.pp_size > 1
        if self.use_pp:
            self.is_first_pp_rank = get_pp_group().is_first_rank
            self.is_last_pp_rank = get_pp_group().is_last_rank
        else:
            self.is_first_pp_rank = True
            self.is_last_pp_rank = True

        # Decode context parallelism.
        self.dcp_size = self.parallel_config.decode_context_parallel_size
        self.use_dcp = self.dcp_size > 1
        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size

        # Speculative decoding.
        self.speculator = None
        self.num_speculative_steps = 0
        self.use_aux_hidden_state_outputs = False
        if self.speculative_config is not None:
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
            if self.speculative_config.method == "eagle3":
                # EAGLE3 may require auxiliary hidden states from target model
                # outputs. Validate before building the speculator so the
                # unsupported configuration fails fast.
                if self.use_pp:
                    raise ValueError("EAGLE3 with pipeline parallel is not supported.")
                self.use_aux_hidden_state_outputs = True
            # Only the last PP rank owns a speculator.
            if self.is_last_pp_rank:
                self.speculator = init_speculator(self.vllm_config, self.device)

        # Draft tokens propagation - for spec-dec + struct outputs.
        self.draft_tokens_handler = DraftTokensHandler(self.device)

        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=self.device,
        )
        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            device=self.device,
        )
        self.sampler = Sampler(
            max_num_reqs=self.max_num_reqs,
            vocab_size=self.vocab_size,
            device=self.device,
            req_states=self.req_states,
            logprobs_mode=self.model_config.logprobs_mode,
            num_speculative_tokens=self.num_speculative_steps + 1,
        )
        self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            self.vllm_config,
            self.uses_mrope,
            self.use_aux_hidden_state_outputs,
            self.device,
        )
        # Structured outputs worker: one logits row per request per
        # (draft + bonus) position.
        self.structured_outputs_worker = StructuredOutputsWorker(
            max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
            vocab_size=self.vocab_size,
            device=self.device,
        )
        # LoRA-related workers.
        self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
        # KV connector; a no-op until initialize_kv_cache installs a real one.
        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

    def update_max_model_len(self, max_model_len: int) -> None:
        """Set a new maximum model length on the runner and its request state.

        NOTE(review): ``BlockTables`` is constructed with ``max_model_len`` in
        ``initialize_kv_cache`` and is not updated here — confirm this method
        is only called before the KV cache is initialized.
        """
        self.max_model_len = max_model_len
        self.req_states.max_model_len = max_model_len

    @staticmethod
    def get_supported_tasks() -> tuple[str]:
        """Return the tasks this runner supports: text generation only."""
        return ("generate",)

    def load_model(self, *args, **kwargs) -> None:
        """Load the target model plus any LoRA wrapper and speculator weights.

        The device memory consumed while loading is measured with
        ``DeviceMemoryProfiler``, stored in ``self.model_memory_usage`` and
        logged together with the load time. Communication buffers are then
        prepared for the model and, when present, for the speculator.

        Args:
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.
        """
        time_before_load = time.perf_counter()
        with DeviceMemoryProfiler() as mem_profiler:
            model_loader = get_model_loader(self.vllm_config.load_config)
            logger.info("Loading model from scratch...")

            self.model = model_loader.load_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.model_config,
            )
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model, self.vllm_config, self.device
                )

            if self.use_aux_hidden_state_outputs:
                # EAGLE3 needs auxiliary hidden states from the target model;
                # this flag is only set when a speculative config exists.
                assert self.speculative_config is not None
                set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
            if self.speculator is not None:
                self.speculator.load_model(self.model)
        time_after_load = time.perf_counter()

        self.model_memory_usage = mem_profiler.consumed_memory
        logger.info(
            "Model loading took %s GiB and %.6f seconds",
            format_gib(mem_profiler.consumed_memory),
            time_after_load - time_before_load,
        )

        prepare_communication_buffer_for_model(self.model)
        if self.speculator is not None:
            prepare_communication_buffer_for_model(self.speculator)

    def get_model(self) -> nn.Module:
        """Return the model set by ``load_model``."""
        return self.model

    @functools.cached_property
    def main_stream(self) -> torch.cuda.Stream:
        """CUDA stream that is current on ``self.device`` at first access.

        The value is cached to avoid repeated lookup overhead. Because it is
        captured only once, first accessing this property inside a
        non-default stream context would pin that stream rather than the
        default one.
        """
        return torch.cuda.current_stream(self.device)

    def get_kv_cache_spec(self):
        """Return the KV-cache spec for this runner's vLLM config.

        Thin delegate to the module-level ``get_kv_cache_spec`` helper.
        """
        config = self.vllm_config
        return get_kv_cache_spec(config)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Build block tables, attention backends and KV-cache tensors.

        The config is deep-copied so the runner holds its own copy. In order,
        this method:

        * builds ``BlockTables`` with one block size per KV-cache group;
        * initializes attention backends/groups and checks context-parallel
          compatibility;
        * hands attention state to the speculator, if any;
        * allocates the KV caches into ``self.kv_caches``;
        * installs the KV connector built from the resulting caches.

        Args:
            kv_cache_config: KV-cache layout chosen for this worker.
        """
        kv_cache_config = deepcopy(kv_cache_config)
        self.kv_cache_config = kv_cache_config
        block_sizes = [
            group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups
        ]

        self.block_tables = BlockTables(
            block_sizes=block_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
            cp_size=self.dcp_size,
            cp_rank=self.dcp_rank,
            cp_interleave=self.cp_interleave,
        )

        self.attn_backends, self.attn_groups = init_attn_backend(
            self.kv_cache_config, self.vllm_config, self.device
        )
        check_attention_cp_compatibility(self.vllm_config)
        if self.speculator is not None:
            # HACK(woosuk)
            self.speculator.set_attn(
                self.kv_cache_config,
                self.attn_groups,
                self.block_tables,
            )

        # init_kv_cache appends the allocated tensors to this list.
        self.kv_caches: list[torch.Tensor] = []
        kv_caches_dict = init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
        )
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

    def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
        """Attach dummy attention metadata and slot mappings to a batch.

        Builds placeholder block tables and slot mappings from
        ``self.block_tables``, sized by the batch's request and token counts.
        Uses ``self.max_model_len`` as the maximum sequence length. The
        results are written to ``input_batch.attn_metadata`` and
        ``input_batch.slot_mappings`` (the latter keyed per layer).

        Args:
            input_batch: Batch to populate in place.
        """
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
        attn_metadata = build_attn_metadata(
            attn_groups=self.attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
            # NOTE(review): passes the full buffer, whereas prepare_inputs
            # slices it to [:num_reqs] — confirm the backend tolerates this.
            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
        )
        input_batch.attn_metadata = attn_metadata
        input_batch.slot_mappings = slot_mappings_by_layer

    @torch.inference_mode()
    def _dummy_run(
        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run one forward pass on a synthetic batch of ``num_tokens`` tokens.

        The tokens are split as evenly as possible across
        ``min(num_tokens, max_num_reqs)`` dummy requests; the last request
        absorbs the remainder. The batch is fed through ``execute_model`` with
        ``dummy_run=True``. The KV connector is disabled for the run and is
        re-enabled even if the forward pass raises.

        Args:
            num_tokens: Total number of tokens in the dummy batch; must be >= 1.
            *args: Accepted for interface compatibility; unused.
            skip_attn: Forwarded to ``execute_model`` as
                ``skip_attn_for_dummy_run``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            On the last PP rank, ``(hidden_states, sample_hidden_states)``,
            where ``sample_hidden_states`` is ``hidden_states`` indexed by the
            batch's ``logits_indices``. On other PP ranks, ``(None, None)``.

        Raises:
            ValueError: If ``num_tokens`` is less than 1.
        """
        if num_tokens < 1:
            # Without this guard, 0 hits a ZeroDivisionError below and a
            # negative value fails with an opaque IndexError.
            raise ValueError(f"num_tokens must be positive, got {num_tokens}")

        # Create a dummy scheduler output.
        num_reqs = min(num_tokens, self.max_num_reqs)
        num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
        num_tokens_per_request[-1] += num_tokens % num_reqs
        assert sum(num_tokens_per_request) == num_tokens
        num_scheduled_tokens = {
            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
        }
        dummy_scheduler_output = SchedulerOutput.make_empty()
        dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
        dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens

        # Disable any use of KVConnector for dummy runs. The finally clause
        # guarantees it is re-enabled even if the forward pass fails.
        self.kv_connector.set_disabled(True)
        try:
            # For non-first PP ranks, create dummy intermediate_tensors.
            intermediate_tensors = None
            if not self.is_first_pp_rank:
                intermediate_tensors = self.model.make_empty_intermediate_tensors(
                    batch_size=num_tokens,
                    dtype=self.model_config.dtype,
                    device=self.device,
                )

            # Execute the model.
            self.execute_model(
                dummy_scheduler_output,
                intermediate_tensors=intermediate_tensors,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
            )
        finally:
            self.kv_connector.set_disabled(False)

        # Non-last PP ranks don't produce output for sampling.
        if not self.is_last_pp_rank:
            return None, None

        assert self.execute_model_state is not None
        hidden_states, _, input_batch, _ = self.execute_model_state
        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        return hidden_states, sample_hidden_states

    @torch.inference_mode()
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
        """Run the sampler once on logits computed from ``hidden_states``.

        Each row of ``hidden_states`` is treated as one dummy request with an
        identity index mapping, zero positions, zero input ids and zero
        expanded local positions.

        Args:
            hidden_states: Rows to sample from; ``shape[0]`` is used as the
                number of dummy requests.
        """
        num_reqs = hidden_states.shape[0]
        logits = self.model.compute_logits(hidden_states)
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
        dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
        expanded_local_pos = torch.zeros(
            num_reqs, dtype=torch.int32, device=self.device
        )
        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
        # top_k, top_p, and logprobs, using less GPU memory than what is possible
        # during actual execution.
        # NOTE(review): idx_mapping_np is passed for two consecutive parameters;
        # presumably the second is the expanded mapping, which equals the plain
        # one when there are no draft tokens — confirm against Sampler.__call__.
        self.sampler(
            logits,
            idx_mapping,
            idx_mapping_np,
            idx_mapping_np,
            pos,
            dummy_input_ids,
            expanded_local_pos,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Exercise a maximum-size step for memory profiling.

        Runs a dummy forward pass over ``max_num_tokens`` with attention
        skipped. On the last PP rank, it then runs the sampler and, if
        configured, one speculator forward. Finally it synchronizes the
        device and drops the intermediates.
        """
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens, skip_attn=True
        )

        # Only run sampler on last PP rank (non-last ranks return None).
        if self.is_last_pp_rank:
            assert sample_hidden_states is not None
            self._dummy_sampler_run(sample_hidden_states)

            if self.speculator is not None:
                num_tokens_across_dp = make_num_tokens_across_dp(
                    self.parallel_config.data_parallel_size, self.max_num_tokens
                )
                self.speculator.run_model(
                    self.max_num_tokens,
                    attn_metadata=None,
                    slot_mappings=None,
                    num_tokens_across_dp=num_tokens_across_dp,
                )

        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        """Reset the encoder runner's multimodal cache; no-op for text-only models."""
        if self.supports_mm_inputs:
            self.encoder_runner.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        """Reset the encoder runner's encoder cache; no-op for text-only models."""
        if self.supports_mm_inputs:
            self.encoder_runner.reset_encoder_cache()

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        """Return the number of input tokens to run for a step.

        Sequence parallelism is not supported yet, so no padding or splitting
        is applied: the scheduled count is returned unchanged.
        """
        num_input_tokens = num_scheduled_tokens
        return num_input_tokens

    @torch.inference_mode()
    def capture_model(self) -> int:
        """Capture CUDA graphs for the model (and the speculator, if any).

        Capture is skipped with a warning when the CUDA graph manager reports
        nothing to capture, or when pipeline parallelism is enabled (PP is
        eager-only for now). Capture runs inside ``maybe_setup_dummy_loras``.

        Returns:
            The drop in free device memory across capture, in bytes, as
            measured by ``torch.cuda.mem_get_info``. Returns 0 when capture
            is skipped.
        """
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        # TODO (zhanqiu): support CUDA graph for PP.
        if self.use_pp:
            logger.warning_once(
                "Skipping CUDA graph capture because pipeline parallel is "
                "enabled. Pipeline parallel is currently eager-only.",
            )
            return 0

        start_time = time.perf_counter()
        # Release cached allocations first so the free-memory delta reflects
        # only what graph capture consumes.
        gc.collect()
        torch.cuda.empty_cache()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            mrope_positions = None
            if self.uses_mrope:
                mrope_positions = self.mrope_states.mrope_positions
            inputs_embeds = None
            if self.supports_mm_inputs:
                inputs_embeds = self.encoder_runner.inputs_embeds
            self.cudagraph_manager.capture(
                model=self.model,
                input_buffers=self.input_buffers,
                mrope_positions=mrope_positions,
                inputs_embeds=inputs_embeds,
                block_tables=self.block_tables,
                attn_groups=self.attn_groups,
                kv_cache_config=self.kv_cache_config,
                has_lora=self.lora_config is not None,
            )
            if self.speculator is not None:
                self.speculator.capture_model()

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        """Run one dummy prefill to trigger FlashInfer JIT compilation.

        The warmup only runs when at least one attention backend exists and
        every backend is FlashInfer. The non-empty check matters because
        ``all()`` over an empty collection is vacuously true. Unlike
        ``profile_run``, the dummy run is made with ``skip_attn=False``.
        """
        if self.attn_backends and all(
            "FLASHINFER" in b.get_name() for b in self.attn_backends.values()
        ):
            self._dummy_run(self.max_num_tokens, skip_attn=False)
            torch.cuda.synchronize()

    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Drop per-request state for finished and preempted requests.

        Preempted requests are removed exactly like finished ones. Each is
        removed from the request state, the encoder runner (multimodal models
        only), the prompt-logprobs worker and the LoRA state.
        """
        finished_req_ids = scheduler_output.finished_req_ids
        preempted_req_ids = scheduler_output.preempted_req_ids
        if preempted_req_ids:
            finished_req_ids = finished_req_ids.union(preempted_req_ids)
        for req_id in finished_req_ids:
            self.req_states.remove_request(req_id)
            if self.supports_mm_inputs:
                self.encoder_runner.remove_request(req_id)
            self.prompt_logprobs_worker.remove_request(req_id)
            self.lora_state.remove_request(req_id)

    def free_states(self, scheduler_output: SchedulerOutput) -> None:
        """Free encoder-cache entries the scheduler marked for release.

        No-op for models without multimodal inputs.
        """
        if self.supports_mm_inputs:
            for mm_hash in scheduler_output.free_encoder_mm_hashes:
                self.encoder_runner.free_encoder_cache(mm_hash)

    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Register newly scheduled requests with every per-request component.

        Each new request is first added to ``req_states``, and its slot index
        is then read back from ``req_states.req_id_to_index``. That index is
        used to register the request with the encoder runner (multimodal
        only), M-RoPE state (M-RoPE only), block tables, sampler,
        prompt-logprobs worker and LoRA state. If any request was added, the
        staged writes of the request state, sampler and (if used) M-RoPE
        state are applied once at the end.
        """
        for new_req_data in scheduler_output.scheduled_new_reqs:
            assert new_req_data.prompt_token_ids is not None
            assert new_req_data.prefill_token_ids is not None
            assert new_req_data.sampling_params is not None
            req_id = new_req_data.req_id
            prompt_len = len(new_req_data.prompt_token_ids)
            self.req_states.add_request(
                req_id=req_id,
                prompt_len=prompt_len,
                all_token_ids=new_req_data.prefill_token_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
            )
            req_index = self.req_states.req_id_to_index[req_id]

            if self.supports_mm_inputs:
                self.encoder_runner.add_request(req_id, new_req_data.mm_features)

            # Pre-compute M-RoPE positions for prefill.
            if self.uses_mrope:
                self.mrope_states.init_prefill_mrope_positions(
                    req_index,
                    self.model,  # type: ignore
                    new_req_data.prefill_token_ids,
                    mm_features=new_req_data.mm_features,
                )

            # A new request's block table row is overwritten, not appended to.
            self.block_tables.append_block_ids(
                req_index, new_req_data.block_ids, overwrite=True
            )
            self.sampler.add_request(
                req_index, prompt_len, new_req_data.sampling_params
            )
            self.prompt_logprobs_worker.add_request(
                req_id, req_index, new_req_data.sampling_params
            )
            self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

        if scheduler_output.scheduled_new_reqs:
            self.req_states.apply_staged_writes()
            self.sampler.apply_staged_writes()
            if self.uses_mrope:
                self.mrope_states.apply_staged_writes()

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Append newly allocated block ids for already-scheduled requests.

        Requests whose entry in ``new_block_ids`` is ``None`` are left
        untouched. Blocks are appended to the existing row
        (``overwrite=False``).
        """
        # Add new blocks for the existing requests.
        reqs = scheduler_output.scheduled_cached_reqs
        for req_new_block_ids, req_id in zip(reqs.new_block_ids, reqs.req_ids):
            if req_new_block_ids is not None:
                req_index = self.req_states.req_id_to_index[req_id]
                self.block_tables.append_block_ids(
                    req_index, req_new_block_ids, overwrite=False
                )

    def prepare_inputs(
573
        self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
Woosuk Kwon's avatar
Woosuk Kwon committed
574
575
576
    ) -> InputBatch:
        num_tokens = scheduler_output.total_num_scheduled_tokens
        assert num_tokens > 0
577
578
        num_tokens_per_req = scheduler_output.num_scheduled_tokens
        num_reqs = len(num_tokens_per_req)
Woosuk Kwon's avatar
Woosuk Kwon committed
579
580
581

        # Decode first, then prefill.
        # batch_idx -> req_id
582
583
584
        req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get)  # type: ignore[arg-type]
        numtoks_iter = map(num_tokens_per_req.get, req_ids)
        num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
Woosuk Kwon's avatar
Woosuk Kwon committed
585

586
587
        idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
        idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
588
        idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
Woosuk Kwon's avatar
Woosuk Kwon committed
589

590
        # Get the number of draft tokens for each request.
591
592
        draft_tokens = scheduler_output.scheduled_spec_decode_tokens
        if not draft_tokens:
593
594
595
            # No draft token scheduled (common case).
            total_num_draft_tokens = 0
            total_num_logits = num_reqs
596
            cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
597
598
599
            cu_num_logits = torch.arange(
                num_reqs + 1, device=self.device, dtype=torch.int32
            )
600
            expanded_idx_mapping = idx_mapping
601
602
603
            expanded_local_pos = torch.zeros(
                num_reqs, dtype=torch.int32, device=self.device
            )
604
605
        else:
            num_draft_tokens = np.array(
606
                [len(draft_tokens.get(req_id, ())) for req_id in req_ids],
607
608
609
610
611
                dtype=np.int32,
            )
            total_num_draft_tokens = int(num_draft_tokens.sum())
            total_num_logits = num_reqs + total_num_draft_tokens

612
613
614
615
            num_logits = num_draft_tokens + 1
            cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
            cu_num_logits_np[0] = 0
            np.cumsum(num_logits, out=cu_num_logits_np[1:])
616
            cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
617

618
            max_expand_len = self.num_speculative_steps + 1
619
            expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
620
                idx_mapping, total_num_logits, cu_num_logits, max_expand_len
621
622
            )

Woosuk Kwon's avatar
Woosuk Kwon committed
623
624
625
        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

626
        # Get query_start_loc.
627
628
629
        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
630
631
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
632
        query_start_loc_np[num_reqs + 1 :] = num_tokens
633
634
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)

635
636
637
        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
        query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
638
        max_query_len = num_scheduled_tokens.max().item()
639

640
641
642
643
644
645
646
647
648
649
650
        # Get prefill tokens if any.
        if self.req_states.any_prefills(idx_mapping_np):
            prepare_prefill_inputs(
                self.input_buffers.input_ids,
                self.req_states.next_prefill_tokens,
                idx_mapping,
                query_start_loc,
                self.req_states.all_token_ids.gpu,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )
Woosuk Kwon's avatar
Woosuk Kwon committed
651

652
653
654
        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
            idx_mapping,
655
656
            query_start_loc,
            self.req_states.num_computed_tokens.gpu,
657
658
659
660
661
            self.input_buffers.positions,
            self.input_buffers.seq_lens,
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

662
663
        if self.use_dcp:
            # Prepare dcp local seq_lens.
664
665
            prepare_dcp_local_seq_lens(
                self.input_buffers.dcp_local_seq_lens,
666
                self.input_buffers.seq_lens,
667
                num_reqs,
668
669
670
                self.dcp_size,
                self.dcp_rank,
                self.cp_interleave,
671
            )
672
        dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
673

674
675
676
677
678
679
680
681
682
        # Prepare M-RoPE positions.
        if self.uses_mrope:
            self.mrope_states.prepare_mrope_positions(
                idx_mapping,
                query_start_loc,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )

683
        # Some input token ids are directly read from the last sampled tokens
684
685
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
686
            self.input_buffers.input_ids,
Woosuk Kwon's avatar
Woosuk Kwon committed
687
688
            idx_mapping,
            self.req_states.last_sampled_tokens,
689
            query_start_loc,
690
691
            seq_lens,
            self.req_states.prefill_len.gpu,
692
693
694
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
Woosuk Kwon's avatar
Woosuk Kwon committed
695
696
697
698
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
699
700
701
            idx_mapping,
            query_start_loc,
            self.input_buffers.positions[:num_tokens],
Woosuk Kwon's avatar
Woosuk Kwon committed
702
        )
703
704
705
706
        # Layer name -> slot mapping.
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
707
708
709

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
710
            attn_groups=self.attn_groups,
Woosuk Kwon's avatar
Woosuk Kwon committed
711
712
            num_reqs=num_reqs,
            num_tokens=num_tokens,
713
            query_start_loc_gpu=query_start_loc,
714
            query_start_loc_cpu=query_start_loc_cpu,
715
            max_query_len=max_query_len,
Woosuk Kwon's avatar
Woosuk Kwon committed
716
            seq_lens=self.input_buffers.seq_lens,
717
            max_seq_len=self.max_model_len,
Woosuk Kwon's avatar
Woosuk Kwon committed
718
719
720
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
721
            dcp_local_seq_lens=dcp_local_seq_lens,
Woosuk Kwon's avatar
Woosuk Kwon committed
722
723
        )

724
        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
725
        positions = self.input_buffers.positions[:num_tokens_after_padding]
726
727
        mrope_positions = None
        if self.uses_mrope:
728
729
            mrope_positions = self.mrope_states.mrope_positions
            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
Woosuk Kwon's avatar
Woosuk Kwon committed
730
731
732
733
734
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
735
            expanded_idx_mapping=expanded_idx_mapping,
736
            expanded_local_pos=expanded_local_pos,
Woosuk Kwon's avatar
Woosuk Kwon committed
737
738
739
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
740
            num_draft_tokens=total_num_draft_tokens,
741
            query_start_loc=query_start_loc,
Woosuk Kwon's avatar
Woosuk Kwon committed
742
            query_start_loc_np=query_start_loc_np,
743
            seq_lens=seq_lens,
Woosuk Kwon's avatar
Woosuk Kwon committed
744
745
            input_ids=input_ids,
            positions=positions,
746
            mrope_positions=mrope_positions,
747
            inputs_embeds=None,
Woosuk Kwon's avatar
Woosuk Kwon committed
748
            attn_metadata=attn_metadata,
749
            slot_mappings=slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
750
            logits_indices=logits_indices,
751
            cu_num_logits=cu_num_logits,
752
            cu_num_logits_np=cu_num_logits_np,
753
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
Woosuk Kwon's avatar
Woosuk Kwon committed
754
755
        )

    @torch.inference_mode()
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Run the multimodal encoder and gather embeddings for this step.

        The encoder is first executed on the inputs prepared from
        ``scheduled_encoder_inputs``; the resulting embeddings are then
        gathered for the requests in ``input_batch`` using each request's
        prefill length and number of already-computed prefill tokens.

        Returns:
            ``(mm_embeds, is_mm_embed)`` exactly as produced by
            ``encoder_runner.gather_mm_embeddings``.
        """
        encoder = self.encoder_runner
        hashes, encoder_kwargs = encoder.prepare_mm_inputs(scheduled_encoder_inputs)
        # Encoder outputs must exist before they can be gathered below.
        encoder.execute_mm_encoder(self.model, hashes, encoder_kwargs)

        # Host-side per-request prefill bookkeeping, in batch order.
        batch_slots = input_batch.idx_mapping_np
        prefill_lens = self.req_states.prefill_len.np[batch_slots]
        computed_prefill = self.req_states.num_computed_prefill_tokens[batch_slots]

        gathered_embeds, embed_mask = encoder.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            prefill_lens,
            computed_prefill,
        )
        return gathered_embeds, embed_mask

    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
        """Compute logits at the batch's logit positions and sample from them.

        Args:
            hidden_states: Hidden states from the forward pass; the rows used
                for sampling are selected with ``input_batch.logits_indices``.
            input_batch: The prepared batch for this step.
            grammar_output: Structured-output bitmasks applied to the logits
                in-place, or ``None`` when no grammar constraints apply.

        Returns:
            ``(sampler_output, num_sampled, num_rejected)``. When the batch
            carries draft tokens, ``sampler_output.sampled_token_ids`` is
            replaced by the output of ``rejection_sample``. The two counts are
            the values returned by ``get_num_sampled_and_rejected``.
        """
        logits_indices = input_batch.logits_indices
        sample_hidden_states = hidden_states[logits_indices]
        sample_pos = input_batch.positions[logits_indices]
        # Token ids at the logit positions; passed to the sampler and, when
        # speculative decoding is active, to rejection_sample as well.
        input_ids = input_batch.input_ids[logits_indices]

        logits = self.model.compute_logits(sample_hidden_states)
        if grammar_output is not None:
            # Apply grammar bitmask to the logits in-place.
            self.structured_outputs_worker.apply_grammar_bitmask(
                logits,
                input_batch,
                grammar_output.structured_output_request_ids,
                grammar_output.grammar_bitmask,
            )

        # Sample tokens and compute logprobs (if needed).
        sampler_output = self.sampler(
            logits,
            input_batch.expanded_idx_mapping,
            input_batch.idx_mapping_np,
            input_batch.cu_num_logits_np,
            sample_pos,
            input_ids,
            input_batch.expanded_local_pos,
        )

        if input_batch.num_draft_tokens == 0:
            # No draft tokens (common case): start from one sampled token per
            # request; chunked prefills are corrected below.
            num_sampled = torch.ones(
                input_batch.num_reqs, dtype=torch.int32, device=self.device
            )
        else:
            # Rejection sampling for spec decoding.
            sampled_tokens, num_sampled = rejection_sample(
                sampler_output.sampled_token_ids,
                input_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            sampler_output.sampled_token_ids = sampled_tokens

        # Get the number of sampled and rejected tokens.
        # For chunked prefills, num_sampled and num_rejected are both 0.
        num_sampled, num_rejected = get_num_sampled_and_rejected(
            num_sampled,
            input_batch.seq_lens,
            input_batch.cu_num_logits,
            input_batch.idx_mapping,
            self.req_states.prefill_len.gpu,
        )
        return sampler_output, num_sampled, num_rejected

    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Apply this step's sampling results to the per-request state.

        Args:
            input_batch: The batch that was just executed.
            sampled_tokens: Sampled token ids for the batch.
            num_sampled: Per-request number of sampled tokens.
            num_rejected: Per-request number of rejected draft tokens.

        The device-side request-state buffers are handed to ``post_update``.
        Then the host-side ``num_computed_prefill_tokens`` counters are
        advanced by this step's scheduled token counts and clamped to each
        request's prefill length.
        """
        # Update the number of computed tokens.
        post_update(
            input_batch.idx_mapping,
            self.req_states.num_computed_tokens.gpu,
            self.req_states.last_sampled_tokens,
            self.sampler.penalties_state.output_bin_counts,
            sampled_tokens,
            num_sampled,
            num_rejected,
            input_batch.query_start_loc,
            self.req_states.all_token_ids.gpu,
            self.req_states.total_len.gpu,
        )

        # Update the number of computed prefill tokens.
        idx_mapping_np = input_batch.idx_mapping_np
        computed_prefill = self.req_states.num_computed_prefill_tokens
        computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
        # Clamp in place so the counter never exceeds prefill_len. This is
        # applied to the whole array, not only to this batch's rows.
        np.minimum(
            computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
        )

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        """Run the model forward pass for one scheduler step.

        Unless ``dummy_run`` is set, request and block-table state is first
        updated from ``scheduler_output``. Then the local CUDA graph mode/size
        is synchronized with the other data-parallel ranks, the inputs are
        prepared, and the model runs. It runs either by replaying a full CUDA
        graph or by calling ``self.model`` inside a forward context.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Activations from the previous pipeline
                stage. Must be ``None`` on the first PP rank and non-``None``
                otherwise (asserted on the non-FULL path).
            dummy_run: Build a dummy batch instead of using real request
                state (a dummy run for DP or memory profiling).
            skip_attn_for_dummy_run: On a dummy run, skip building dummy
                attention metadata.

        Returns:
            - ``kv_connector.no_forward(scheduler_output)`` when there is
              nothing to run, either locally (non-dummy) or on all DP ranks.
            - ``IntermediateTensors`` on non-last pipeline-parallel ranks.
            - ``None`` on the last rank. In that case the hidden states, aux
              hidden states, batch and KV-connector output are stored in
              ``self.execute_model_state`` for ``sample_tokens``.
        """
        if not dummy_run:
            # Update the request states.
            self.finish_requests(scheduler_output)
            self.free_states(scheduler_output)
            self.add_requests(scheduler_output)
            self.update_requests(scheduler_output)
            self.block_tables.apply_staged_writes()
            if scheduler_output.total_num_scheduled_tokens == 0:
                # No need to run the model.
                empty_output = self.kv_connector.no_forward(scheduler_output)
                return empty_output

        # Get local cudagraph mode and size.
        local_cudagraph_mode, local_cudagraph_size = (
            self.cudagraph_manager.get_cudagraph_runtime_mode(
                num_reqs=len(scheduler_output.num_scheduled_tokens),
                num_tokens=scheduler_output.total_num_scheduled_tokens,
                max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
            )
        )

        # DP sync: num_tokens + cudagraph_size + cudagraph_mode
        num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
            get_cudagraph_and_dp_padding(
                scheduler_output.total_num_scheduled_tokens,
                local_cudagraph_size,
                local_cudagraph_mode.value,
                self.parallel_config.data_parallel_size,
                self.parallel_config.data_parallel_rank,
            )
        )
        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
        if num_tokens_after_padding == 0:
            # All DP ranks have zero tokens to run.
            empty_output = self.kv_connector.no_forward(scheduler_output)
            return empty_output

        if not dummy_run:
            # Common case.
            # Prepare all the inputs and copy to the input buffers.
            input_batch = self.prepare_inputs(
                scheduler_output, num_tokens_after_padding
            )
            if self.lora_config:
                # Activate LoRA adapters.
                lora_inputs = self.lora_state.make_lora_inputs(
                    input_batch.req_ids,
                    input_batch.idx_mapping_np,
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*lora_inputs)

            # Only first PP rank prepares multimodal embeddings.
            if self.supports_mm_inputs and self.is_first_pp_rank:
                mm_embeds, is_mm_embed = self.get_mm_embeddings(
                    scheduler_output.scheduled_encoder_inputs, input_batch
                )
                inputs_embeds = self.encoder_runner.get_inputs_embeds(
                    self.model, input_batch.input_ids, mm_embeds, is_mm_embed
                )
                input_batch.inputs_embeds = inputs_embeds[
                    : input_batch.num_tokens_after_padding
                ]
        else:
            # No actual tokens to run. A dummy run for DP or memory profiling.
            num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
            input_batch = InputBatch.make_dummy(
                num_reqs=num_reqs,
                num_tokens=num_tokens_after_padding,
                input_buffers=self.input_buffers,
                device=self.device,
            )
            if self.uses_mrope:
                input_batch.mrope_positions = self.mrope_states.mrope_positions[
                    :, :num_tokens_after_padding
                ]
            if not skip_attn_for_dummy_run:
                self.prepare_dummy_attn_metadata(input_batch)
            # FIXME(woosuk): Fix warmup for LoRA.

        # Run model.
        if cudagraph_runtime_mode == CUDAGraphMode.FULL:
            # Use explicit cudagraph replay for FULL mode.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            self.kv_connector.pre_forward(scheduler_output)
            model_output = self.cudagraph_manager.run_fullgraph(
                input_batch.num_tokens_after_padding
            )
        else:
            # For piecewise and eager mode, just call model().
            positions = input_batch.positions
            if self.uses_mrope:
                assert input_batch.mrope_positions is not None
                positions = input_batch.mrope_positions

            if self.is_first_pp_rank:
                input_ids = input_batch.input_ids
                inputs_embeds = input_batch.inputs_embeds
                assert intermediate_tensors is None
            else:
                input_ids = None
                inputs_embeds = None
                assert intermediate_tensors is not None

            batch_descriptor = BatchDescriptor(
                num_tokens=input_batch.num_tokens_after_padding,
                has_lora=self.lora_config is not None,
            )

            with set_forward_context(
                input_batch.attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                batch_descriptor=batch_descriptor,
                slot_mapping=input_batch.slot_mappings,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    inputs_embeds=inputs_embeds,
                    intermediate_tensors=intermediate_tensors,
                )

        # Both execution paths produce the same output shape: a
        # (hidden, aux) pair when aux hidden states are enabled, otherwise
        # the hidden states alone.
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        kv_connector_output = self.kv_connector.post_forward(scheduler_output)

        if not self.is_last_pp_rank:
            # Non-last PP rank: return IntermediateTensors for sending.
            assert isinstance(hidden_states, IntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
            self.execute_model_state = (None, None, input_batch, kv_connector_output)
            return hidden_states

        # Last rank (or no PP): hidden_states is a tensor for sampling.
        assert isinstance(hidden_states, torch.Tensor)
        self.execute_model_state = (
            hidden_states,
            aux_hidden_states,
            input_batch,
            kv_connector_output,
        )  # type: ignore
        return None

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: GrammarOutput | None
    ) -> AsyncOutput | ModelRunnerOutput | None:
        """Finish the step started by ``execute_model``.

        This consumes (and clears) ``self.execute_model_state``. Non-last
        pipeline-parallel ranks receive the sampled tokens from the last rank
        and update local state. The last rank samples, broadcasts the results
        when PP is enabled, computes prompt logprobs, updates request state
        and optionally proposes draft tokens.

        Args:
            grammar_output: Structured-output bitmasks forwarded to
                ``self.sample``, or ``None``.

        Returns:
            ``None`` on non-last PP ranks. On the last rank, the
            ``AsyncOutput`` itself when async scheduling is enabled, otherwise
            its resolved ``get_output()`` value.
        """
        assert self.execute_model_state is not None
        hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
            self.execute_model_state
        )
        self.execute_model_state = None  # type: ignore

        if not self.is_last_pp_rank:
            # Non-last PP rank: hidden_states is None because this rank produced
            # IntermediateTensors instead of final hidden states. Receive the
            # sampled tokens broadcast from the last rank and update local state.
            # NOTE(review): this path returns before the speculator runs, so
            # req_states.draft_tokens is not updated here; confirm spec decode
            # with PP keeps draft tokens in sync elsewhere.
            sampled, num_sampled, num_rejected = pp_receive(
                input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
            )
            self.postprocess(input_batch, sampled, num_sampled, num_rejected)
            return None

        # Last rank: sample tokens
        sampler_output, num_sampled, num_rejected = self.sample(
            hidden_states, input_batch, grammar_output
        )

        if self.use_pp:
            # Broadcast to non-last PP ranks (handles spec decode multi-token).
            pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)

        prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
            self.model.compute_logits,
            hidden_states,
            input_batch,
            self.req_states.all_token_ids.gpu,
            self.req_states.num_computed_tokens.gpu,
            self.req_states.prompt_len.np,
            self.req_states.prefill_len.np,
            self.req_states.num_computed_prefill_tokens,
        )

        # Prepare the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            # NOTE(woosuk): req_id_to_index is unused in this model runner.
            # Only for compatibility with the existing model runner and scheduler.
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            sampled_token_ids=None,  # type: ignore
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        # Postprocess results and update request states.
        # NOTE: This is intentionally done after creating the AsyncOutput,
        # ensuring that `copy_event` is recorded before calling postprocess.
        # This sequencing may slightly reduce latency as async D2H copy does not
        # need to wait for the postprocess to finish.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
        )
        if self.speculator is not None:
            draft_tokens = self.speculator.propose(
                input_batch,
                hidden_states,
                aux_hidden_states,
                num_sampled,
                num_rejected,
                self.req_states.last_sampled_tokens,
                self.req_states.next_prefill_tokens,
                self.sampler.sampling_states.temperature.gpu,
                self.sampler.sampling_states.seeds.gpu,
            )
            # Store the proposals for the next step's inputs and expose them
            # through the draft-token handler.
            self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)

        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft token ids held by the draft-token handler, or None."""
        handler = self.draft_tokens_handler
        return handler.get_draft_tokens()