model_runner.py 44 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""
NOTE: Coding style guide for this file:
This model runner is shared by all models: text and multimodal, generative
and embedding, public and private. As a result, this file must only contain
code that is common to every model. Model-specific behavior belongs in the
appropriate model-specific files.

In other words:
* Be paranoid about changing this file. It should remain stable.
* Be even more paranoid about adding new lines. It should remain minimal.

Even for shared features (for example, different parallelism modes), keep the
complexity out of this path. The less common the feature, the more it should be
hidden. Prefer utility functions defined elsewhere and call them from here,
instead of embedding feature-specific logic directly.
"""

20
import functools
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
24
25
26
27
28
29
30
import gc
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
31
from vllm.distributed.parallel_state import (
32
    get_dcp_group,
33
34
35
    get_pp_group,
    prepare_communication_buffer_for_model,
)
36
from vllm.forward_context import BatchDescriptor, set_forward_context
Woosuk Kwon's avatar
Woosuk Kwon committed
37
38
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
39
from vllm.multimodal import MULTIMODAL_REGISTRY
40
from vllm.sequence import IntermediateTensors
41
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
Woosuk Kwon's avatar
Woosuk Kwon committed
42
43
44
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
45
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
46
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
47
from vllm.v1.worker.gpu.async_utils import AsyncOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
48
from vllm.v1.worker.gpu.attn_utils import (
49
    build_slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
50
51
52
53
54
    get_kv_cache_spec,
    init_attn_backend,
    init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
55
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
56
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
Woosuk Kwon's avatar
Woosuk Kwon committed
57
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
58
from vllm.v1.worker.gpu.dp_utils import (
59
    get_cudagraph_and_dp_padding,
60
61
    make_num_tokens_across_dp,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
62
63
64
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
65
    combine_sampled_and_draft_tokens,
66
    expand_idx_mapping,
67
    get_num_sampled_and_rejected,
68
    post_update,
69
70
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
Woosuk Kwon's avatar
Woosuk Kwon committed
71
)
72
73
74
75
76
from vllm.v1.worker.gpu.kv_connector import (
    NO_OP_KV_CONNECTOR,
    KVConnector,
    get_kv_connector,
)
77
from vllm.v1.worker.gpu.lora_utils import LoraState
78
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
79
from vllm.v1.worker.gpu.model_states import ModelState
80
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
81
from vllm.v1.worker.gpu.sample.output import SamplerOutput
82
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
83
from vllm.v1.worker.gpu.sample.sampler import Sampler
84
from vllm.v1.worker.gpu.spec_decode import init_speculator
85
86
87
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
    set_eagle3_aux_hidden_state_layers,
)
88
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
89
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
90
from vllm.v1.worker.gpu.states import RequestState
91
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Woosuk Kwon's avatar
Woosuk Kwon committed
92
93
94
95
96
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

logger = init_logger(__name__)


97
class GPUModelRunner(LoRAModelRunnerMixin):
Woosuk Kwon's avatar
Woosuk Kwon committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up all runner state that does not require the loaded model.

        Args:
            vllm_config: Top-level configuration. Each sub-config used by the
                runner is also stored as its own attribute.
            device: Device passed to the buffers, streams and workers created
                here.

        Raises:
            ValueError: If the speculative method is "eagle3" while the
                pipeline-parallel size is greater than one.
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        spec_config = vllm_config.speculative_config

        # Keep a direct reference to every sub-config.
        self.vllm_config = vllm_config
        self.model_config = model_config
        self.cache_config = cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.speculative_config = spec_config
        self.observability_config = vllm_config.observability_config

        self.device = device
        self.dtype = model_config.dtype
        cache_dtype = cache_config.cache_dtype
        if cache_dtype == "auto":
            # "auto" means the KV cache shares the model dtype.
            self.kv_cache_dtype = self.dtype
        else:
            # Quantized KV cache.
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        self.is_pooling_model = False

        # Size limits.
        self.vocab_size = model_config.get_vocab_size()
        self.max_model_len = model_config.max_model_len
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.inputs_embeds_size = model_config.get_inputs_embeds_size()

        # Multimodal: the encoder runner only exists for multimodal models.
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config
        )
        if self.supports_mm_inputs:
            self.encoder_runner = EncoderRunner(
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                dtype=self.dtype,
                device=device,
            )

        # Async scheduling flag plus a dedicated stream/event pair for outputs.
        self.use_async_scheduling = scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(device)
        self.output_copy_event = torch.cuda.Event()

        # Pipeline parallelism. Without PP this rank is both first and last.
        self.pp_size = parallel_config.pipeline_parallel_size
        self.use_pp = self.pp_size > 1
        self.is_first_pp_rank = get_pp_group().is_first_rank if self.use_pp else True
        self.is_last_pp_rank = get_pp_group().is_last_rank if self.use_pp else True

        # Decode context parallelism.
        self.dcp_size = parallel_config.decode_context_parallel_size
        self.use_dcp = self.dcp_size > 1
        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
        self.cp_interleave = parallel_config.cp_kv_cache_interleave_size

        # Speculative decoding; the defaults describe the "disabled" state.
        self.speculator = None
        self.num_speculative_steps = 0
        self.use_aux_hidden_state_outputs = False
        if spec_config is not None:
            self.num_speculative_steps = spec_config.num_speculative_tokens
            # A speculator is only created on the last PP rank.
            if self.is_last_pp_rank:
                self.speculator = init_speculator(vllm_config, device)

            if spec_config.method == "eagle3":
                # EAGLE3 may require auxiliary hidden states from target model outputs.
                self.use_aux_hidden_state_outputs = True
                if self.pp_size > 1:
                    raise ValueError("EAGLE3 with pipeline parallel is not supported.")

        # Number of logits per request: one sampled token plus every draft token.
        max_logits_per_req = self.num_speculative_steps + 1

        # Draft tokens propagation - for spec-dec + struct outputs.
        self.draft_tokens_handler = DraftTokensHandler(device)

        # Persistent per-request state and per-step input buffers.
        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=device,
        )
        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            device=device,
        )

        # Sampling.
        self.sampler = Sampler(
            max_num_reqs=self.max_num_reqs,
            vocab_size=self.vocab_size,
            device=device,
            req_states=self.req_states,
            logprobs_mode=model_config.logprobs_mode,
            num_speculative_tokens=max_logits_per_req,
        )
        self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            vllm_config,
            self.use_aux_hidden_state_outputs,
            device,
        )

        # Structured outputs worker.
        self.structured_outputs_worker = StructuredOutputsWorker(
            max_num_logits=self.max_num_reqs * max_logits_per_req,
            vocab_size=self.vocab_size,
            device=device,
        )

        # LoRA-related workers.
        self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)

        # KV connector: a no-op placeholder until initialize_kv_cache replaces it.
        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

        # For transferring state from execute_model to subsequent sample_tokens call.
        self.execute_model_state: tuple | None = None

223
224
225
226
    def update_max_model_len(self, max_model_len: int) -> None:
        """Store a new maximum model length on the runner and its request state."""
        # Both copies are set to the same value.
        self.req_states.max_model_len = self.max_model_len = max_model_len

227
228
    @staticmethod
    def get_supported_tasks() -> tuple[str]:
Woosuk Kwon's avatar
Woosuk Kwon committed
229
230
231
232
233
234
235
236
237
238
239
240
241
242
        return ("generate",)

    def load_model(self, *args, **kwargs) -> None:
        """Load the model and its optional add-ons, then build `model_state`.

        The extra positional and keyword arguments are accepted but not used
        by this method. Loading time and the memory reported by the device
        memory profiler are logged, and the memory figure is kept in
        `self.model_memory_usage`.
        """
        load_start = time.perf_counter()
        with DeviceMemoryProfiler() as profiler:
            loader = get_model_loader(self.vllm_config.load_config)
            logger.info("Loading model from scratch...")

            self.model = loader.load_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.model_config,
            )
            if self.lora_config:
                # The LoRA loader returns the model that replaces the base one.
                self.model = self.load_lora_model(
                    self.model, self.vllm_config, self.device
                )

            if self.use_aux_hidden_state_outputs:
                assert self.speculative_config is not None
                set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
            if self.speculator is not None:
                self.speculator.load_model(self.model)
        load_seconds = time.perf_counter() - load_start

        self.model_memory_usage = profiler.consumed_memory
        logger.info(
            "Model loading took %s GiB and %.6f seconds",
            format_gib(profiler.consumed_memory),
            load_seconds,
        )

        # Prepare communication buffers for the model, then the speculator if any.
        comm_targets = [self.model]
        if self.speculator is not None:
            comm_targets.append(self.speculator)
        for target in comm_targets:
            prepare_communication_buffer_for_model(target)

        # Initialize the components that require the model.
        self.model_state = ModelState(self.vllm_config, self.model, self.device)

Woosuk Kwon's avatar
Woosuk Kwon committed
267
268
269
    def get_model(self) -> nn.Module:
        """Return the model stored on the runner by `load_model`."""
        loaded_model = self.model
        return loaded_model

270
271
272
273
274
    @functools.cached_property
    def main_stream(self) -> torch.cuda.Stream:
        """CUDA stream that is current for `self.device` at first access.

        The value is computed once and cached on the instance, so later
        accesses skip the stream lookup.
        """
        stream = torch.cuda.current_stream(self.device)
        return stream

Woosuk Kwon's avatar
Woosuk Kwon committed
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
    def get_kv_cache_spec(self):
        """Return the KV cache spec that the module-level helper derives from the config."""
        kv_cache_spec = get_kv_cache_spec(self.vllm_config)
        return kv_cache_spec

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Build block tables, attention backends, KV caches and the KV connector.

        Args:
            kv_cache_config: KV cache layout. A deep copy is stored on the
                runner, so the caller's object is not shared.
        """
        # Work on a private copy of the config.
        self.kv_cache_config = deepcopy(kv_cache_config)

        # One block size per KV cache group.
        group_block_sizes = []
        for group in self.kv_cache_config.kv_cache_groups:
            group_block_sizes.append(group.kv_cache_spec.block_size)

        self.block_tables = BlockTables(
            block_sizes=group_block_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
            cp_size=self.dcp_size,
            cp_rank=self.dcp_rank,
            cp_interleave=self.cp_interleave,
        )

        backends_and_groups = init_attn_backend(
            self.kv_cache_config, self.vllm_config, self.device
        )
        self.attn_backends, self.attn_groups = backends_and_groups
        check_attention_cp_compatibility(self.vllm_config)

        speculator = self.speculator
        if speculator is not None:
            # HACK(woosuk)
            speculator.set_attn(
                self.kv_cache_config,
                self.attn_groups,
                self.block_tables,
            )

        # The list is created empty and handed to init_kv_cache.
        self.kv_caches: list[torch.Tensor] = []
        kv_caches_by_name = init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
        )
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_by_name)

Woosuk Kwon's avatar
Woosuk Kwon committed
319
320
    @torch.inference_mode()
    def _dummy_run(
        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run the model once on a synthetic batch of `num_tokens` tokens.

        Args:
            num_tokens: Total number of tokens in the dummy batch. They are
                spread evenly over `min(num_tokens, self.max_num_reqs)` dummy
                requests; the last request takes the remainder.
            *args: Accepted but not used by this method.
            skip_attn: Forwarded to `execute_model` as
                `skip_attn_for_dummy_run`.
            **kwargs: Accepted but not used by this method.

        Returns:
            `(hidden_states, sample_hidden_states)` on the last pipeline
            stage, where the second item is `hidden_states` indexed by the
            batch's `logits_indices`. `(None, None)` on every other stage.
        """
        # Create a dummy scheduler output.
        num_reqs = min(num_tokens, self.max_num_reqs)
        num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
        num_tokens_per_request[-1] += num_tokens % num_reqs
        assert sum(num_tokens_per_request) == num_tokens
        num_scheduled_tokens = {
            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
        }
        dummy_scheduler_output = SchedulerOutput.make_empty()
        dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
        dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens

        # Disable any use of KVConnector for dummy runs. The try/finally makes
        # sure the connector is re-enabled even when the dummy forward pass
        # raises; otherwise it would stay disabled for all later real runs.
        self.kv_connector.set_disabled(True)
        try:
            # For non-first PP ranks, create dummy intermediate_tensors.
            intermediate_tensors = None
            if not self.is_first_pp_rank:
                intermediate_tensors = self.model.make_empty_intermediate_tensors(
                    batch_size=num_tokens,
                    dtype=self.model_config.dtype,
                    device=self.device,
                )

            # Execute the model.
            self.execute_model(
                dummy_scheduler_output,
                intermediate_tensors=intermediate_tensors,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
            )
        finally:
            self.kv_connector.set_disabled(False)

        # Non-last PP ranks don't produce output for sampling.
        if not self.is_last_pp_rank:
            return None, None

        # Consume the state that execute_model left for the sampling step.
        assert self.execute_model_state is not None
        input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state
        self.execute_model_state = None
        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        return hidden_states, sample_hidden_states

    @torch.inference_mode()
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
        """Run the sampler once on placeholder inputs.

        Args:
            hidden_states: Tensor whose first dimension is used as the number
                of dummy requests; logits are computed from it.
        """
        batch = hidden_states.shape[0]
        logits = self.model.compute_logits(hidden_states)

        dev = self.device
        # Identity request mapping, once on the device and once as NumPy.
        identity_idx = torch.arange(batch, dtype=torch.int32, device=dev)
        identity_idx_np = np.arange(batch, dtype=np.int32)
        # All-zero placeholders for the remaining sampler inputs.
        zero_pos = torch.zeros(batch, dtype=torch.int64, device=dev)
        zero_input_ids = torch.zeros(batch, dtype=torch.int32, device=dev)
        zero_local_pos = torch.zeros(batch, dtype=torch.int32, device=dev)

        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
        # top_k, top_p, and logprobs, using less GPU memory than what is possible
        # during actual execution.
        self.sampler(
            logits,
            identity_idx,
            identity_idx_np,
            identity_idx_np,
            zero_pos,
            zero_input_ids,
            zero_local_pos,
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
390
391
392
393

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a maximum-size dummy forward pass plus sampler and speculator.

        The dummy batch uses `self.max_num_tokens` tokens. Afterwards the
        device is synchronized, the outputs are dropped and the garbage
        collector is run.
        """
        max_tokens = self.max_num_tokens
        hidden_states, sample_hidden_states = self._dummy_run(
            max_tokens, skip_attn=True
        )

        # Only run sampler on last PP rank (non-last ranks return None).
        on_last_rank = self.is_last_pp_rank
        if on_last_rank:
            assert sample_hidden_states is not None
            self._dummy_sampler_run(sample_hidden_states)

        # The speculator is also exercised on the last PP rank only.
        if on_last_rank and self.speculator is not None:
            dp_size = self.parallel_config.data_parallel_size
            num_tokens_across_dp = make_num_tokens_across_dp(dp_size, max_tokens)
            self.speculator.run_model(
                max_tokens,
                attn_metadata=None,
                slot_mappings=None,
                num_tokens_across_dp=num_tokens_across_dp,
            )

        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        """Reset the encoder runner's multimodal cache; no-op without multimodal support."""
        if not self.supports_mm_inputs:
            return
        self.encoder_runner.reset_mm_cache()
420
421

    def reset_encoder_cache(self) -> None:
        """Reset the encoder runner's encoder cache; no-op without multimodal support."""
        if not self.supports_mm_inputs:
            return
        self.encoder_runner.reset_encoder_cache()
Woosuk Kwon's avatar
Woosuk Kwon committed
424
425
426
427
428
429
430
431
432
433
434
435
436
437

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        # SP is not supported yet.
        return num_scheduled_tokens

    @torch.inference_mode()
    def capture_model(self) -> int:
        """Capture CUDA graphs for the model and the speculator, if any.

        Returns:
            The drop in free GPU memory (bytes) measured across the capture,
            or 0 when capture is skipped.
        """
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        # TODO (zhanqiu): support CUDA graph for PP.
        if self.use_pp:
            logger.warning_once(
                "Skipping CUDA graph capture because pipeline parallel is "
                "enabled. Pipeline parallel is currently eager-only.",
            )
            return 0

        capture_start = time.perf_counter()
        # Release as much memory as possible before taking the baseline.
        gc.collect()
        torch.cuda.empty_cache()
        free_before = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            # Only multimodal models have an encoder runner to take embeddings from.
            inputs_embeds = (
                self.encoder_runner.inputs_embeds if self.supports_mm_inputs else None
            )
            self.cudagraph_manager.capture(
                model=self.model,
                model_state=self.model_state,
                input_buffers=self.input_buffers,
                inputs_embeds=inputs_embeds,
                block_tables=self.block_tables,
                attn_groups=self.attn_groups,
                kv_cache_config=self.kv_cache_config,
                has_lora=self.lora_config is not None,
            )
            if self.speculator is not None:
                self.speculator.capture_model()

        capture_end = time.perf_counter()
        free_after = torch.cuda.mem_get_info()[0]
        cuda_graph_size = free_before - free_after
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            capture_end - capture_start,
            cuda_graph_size / (1 << 30),
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        """Run one full-size dummy prefill when every backend name contains "FLASHINFER"."""
        # For FlashInfer, we would like to execute a dummy prefill run
        # to trigger JIT compilation.
        for backend in self.attn_backends.values():
            if "FLASHINFER" not in backend.get_name():
                # At least one backend is not FlashInfer: nothing to warm up.
                return
        self._dummy_run(self.max_num_tokens, skip_attn=False)
        torch.cuda.synchronize()

487
    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Drop per-request state for finished and preempted requests.

        Args:
            scheduler_output: Supplies `finished_req_ids` and
                `preempted_req_ids`. Both sets are handled the same way.
        """
        to_remove = scheduler_output.finished_req_ids
        preempted = scheduler_output.preempted_req_ids
        if preempted:
            # union() builds a new set, leaving the scheduler's sets untouched.
            to_remove = to_remove.union(preempted)

        has_encoder = self.supports_mm_inputs
        for request_id in to_remove:
            self.req_states.remove_request(request_id)
            if has_encoder:
                self.encoder_runner.remove_request(request_id)
            self.prompt_logprobs_worker.remove_request(request_id)
            self.lora_state.remove_request(request_id)
498

499
    def free_states(self, scheduler_output: SchedulerOutput) -> None:
        """Free the encoder-cache entries the scheduler marked as releasable.

        Does nothing when the model has no multimodal inputs.
        """
        if not self.supports_mm_inputs:
            return
        for freed_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_runner.free_encoder_cache(freed_hash)
Woosuk Kwon's avatar
Woosuk Kwon committed
503

504
    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Register every newly scheduled request with the runner's components.

        Args:
            scheduler_output: Supplies `scheduled_new_reqs`. Each entry must
                carry prompt token ids, prefill token ids and sampling params.
        """
        new_reqs = scheduler_output.scheduled_new_reqs
        for req in new_reqs:
            assert req.prompt_token_ids is not None
            assert req.prefill_token_ids is not None
            assert req.sampling_params is not None

            request_id = req.req_id
            prompt_length = len(req.prompt_token_ids)
            sampling_params = req.sampling_params

            self.req_states.add_request(
                req_id=request_id,
                prompt_len=prompt_length,
                all_token_ids=req.prefill_token_ids,
                num_computed_tokens=req.num_computed_tokens,
            )
            # The index assigned by req_states keys all the remaining components.
            slot = self.req_states.req_id_to_index[request_id]

            if self.supports_mm_inputs:
                self.encoder_runner.add_request(request_id, req.mm_features)

            self.model_state.add_request(slot, req)
            self.block_tables.append_block_ids(slot, req.block_ids, overwrite=True)
            self.sampler.add_request(slot, prompt_length, sampling_params)
            self.prompt_logprobs_worker.add_request(request_id, slot, sampling_params)
            self.lora_state.add_request(request_id, slot, req.lora_request)

        if not new_reqs:
            return
        # Apply the staged writes once for the whole set of new requests.
        self.req_states.apply_staged_writes()
        self.sampler.apply_staged_writes()
        self.model_state.apply_staged_writes()
538
539

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Append newly allocated block ids for already-known requests.

        Args:
            scheduler_output: Supplies `scheduled_cached_reqs`, whose
                `new_block_ids` and `req_ids` are walked in lockstep.
        """
        # Add new blocks for the existing requests.
        cached = scheduler_output.scheduled_cached_reqs
        index_of = self.req_states.req_id_to_index
        for block_ids, request_id in zip(cached.new_block_ids, cached.req_ids):
            if block_ids is None:
                # No new blocks for this request.
                continue
            self.block_tables.append_block_ids(
                index_of[request_id], block_ids, overwrite=False
            )
Woosuk Kwon's avatar
Woosuk Kwon committed
548
549

    def prepare_inputs(
        self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
    ) -> InputBatch:
        """Build the `InputBatch` for one step from the scheduler output.

        Args:
            scheduler_output: Supplies the per-request scheduled token counts
                and the scheduled draft tokens, if any.
            num_tokens_after_padding: Length of the `input_ids` and
                `positions` slices placed in the returned batch.

        Returns:
            An `InputBatch` whose requests are ordered by ascending number of
            scheduled tokens.
        """
        total_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_tokens > 0
        tokens_by_req = scheduler_output.num_scheduled_tokens
        batch_size = len(tokens_by_req)

        # Decode first, then prefill.
        # batch_idx -> req_id
        ordered_req_ids = sorted(tokens_by_req, key=lambda rid: tokens_by_req[rid])
        scheduled_counts = np.fromiter(
            (tokens_by_req.get(rid) for rid in ordered_req_ids),
            dtype=np.int32,
            count=batch_size,
        )

        # batch_idx -> persistent request index, on host and on device.
        lookup_index = self.req_states.req_id_to_index.get
        idx_mapping_np = np.fromiter(
            (lookup_index(rid) for rid in ordered_req_ids),
            dtype=np.int32,
            count=batch_size,
        )
        idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)

        # Get the number of draft tokens for each request.
        scheduled_drafts = scheduler_output.scheduled_spec_decode_tokens
        if scheduled_drafts:
            drafts_per_req = np.array(
                [len(scheduled_drafts.get(rid, ())) for rid in ordered_req_ids],
                dtype=np.int32,
            )
            total_num_draft_tokens = int(drafts_per_req.sum())
            total_num_logits = batch_size + total_num_draft_tokens

            # Each request yields one logit per draft token plus one.
            logits_per_req = drafts_per_req + 1
            cu_num_logits_np = np.empty(batch_size + 1, dtype=np.int32)
            cu_num_logits_np[0] = 0
            np.cumsum(logits_per_req, out=cu_num_logits_np[1:])
            cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)

            expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
                idx_mapping,
                total_num_logits,
                cu_num_logits,
                self.num_speculative_steps + 1,
            )
        else:
            # No draft token scheduled (common case).
            total_num_draft_tokens = 0
            total_num_logits = batch_size
            cu_num_logits_np = np.arange(batch_size + 1, dtype=np.int32)
            cu_num_logits = torch.arange(
                batch_size + 1, device=self.device, dtype=torch.int32
            )
            expanded_idx_mapping = idx_mapping
            expanded_local_pos = torch.zeros(
                batch_size, dtype=torch.int32, device=self.device
            )

        # Get query_start_loc.
        end = batch_size + 1
        qsl_full_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        qsl_full_np[0] = 0
        np.cumsum(scheduled_counts, out=qsl_full_np[1:end])
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
        qsl_full_np[end:] = total_tokens
        async_copy_to_gpu(qsl_full_np, out=self.input_buffers.query_start_loc)
        # Only the first batch_size + 1 entries describe real requests.
        query_start_loc_np = qsl_full_np[:end]
        query_start_loc = self.input_buffers.query_start_loc[:end]

        # Get prefill tokens if any.
        if self.req_states.any_prefills(idx_mapping_np):
            prepare_prefill_inputs(
                self.input_buffers.input_ids,
                self.req_states.next_prefill_tokens,
                idx_mapping,
                query_start_loc,
                self.req_states.all_token_ids.gpu,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )

        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
            idx_mapping,
            query_start_loc,
            self.req_states.num_computed_tokens.gpu,
            self.input_buffers.positions,
            self.input_buffers.seq_lens,
        )
        seq_lens = self.input_buffers.seq_lens[:batch_size]

        # Prepare dcp local seq_lens (decode context parallelism only).
        dcp_local_seq_lens = None
        if self.use_dcp:
            prepare_dcp_local_seq_lens(
                self.input_buffers.dcp_local_seq_lens,
                self.input_buffers.seq_lens,
                batch_size,
                self.dcp_size,
                self.dcp_rank,
                self.cp_interleave,
            )
            dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:batch_size]

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
            self.input_buffers.input_ids,
            idx_mapping,
            self.req_states.last_sampled_tokens,
            query_start_loc,
            seq_lens,
            self.req_states.prefill_len.gpu,
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
        )

        return InputBatch(
            req_ids=ordered_req_ids,
            num_reqs=batch_size,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            expanded_idx_mapping=expanded_idx_mapping,
            expanded_local_pos=expanded_local_pos,
            num_scheduled_tokens=scheduled_counts,
            num_tokens=total_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
            num_draft_tokens=total_num_draft_tokens,
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            dcp_local_seq_lens=dcp_local_seq_lens,
            input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
            positions=self.input_buffers.positions[:num_tokens_after_padding],
            inputs_embeds=None,
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
        )

684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
    def prepare_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        """Build the block tables and slot mappings for a real batch.

        Args:
            input_batch: The prepared batch; its ``idx_mapping``,
                ``query_start_loc`` and ``positions`` are forwarded to
                ``self.block_tables``.

        Returns:
            ``(block_tables, slot_mappings)`` where block tables are
            num_kv_cache_groups x [num_reqs, max_num_blocks] and slot mappings
            are [num_kv_cache_groups, num_tokens].
        """
        tables = self.block_tables
        req_indices = input_batch.idx_mapping
        # Gather first, then compute the slots, in that order.
        gathered_tables = tables.gather_block_tables(req_indices)
        token_slots = tables.compute_slot_mappings(
            req_indices, input_batch.query_start_loc, input_batch.positions
        )
        return gathered_tables, token_slots

    def prepare_dummy_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        """Fetch placeholder block tables and slot mappings for a dummy batch.

        Args:
            input_batch: The dummy batch; only ``num_reqs`` and ``num_tokens``
                are read.

        Returns:
            ``(block_tables, slot_mappings)`` as produced by the dummy getters
            of ``self.block_tables``.
        """
        tables = self.block_tables
        dummy_tables = tables.get_dummy_block_tables(input_batch.num_reqs)
        dummy_slots = tables.get_dummy_slot_mappings(input_batch.num_tokens)
        return dummy_tables, dummy_slots

706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
    @torch.inference_mode()
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Run the multimodal encoder and gather embeddings for this batch.

        Args:
            scheduled_encoder_inputs: Encoder inputs scheduled for this step,
                passed straight to ``encoder_runner.prepare_mm_inputs``.
            input_batch: The prepared batch whose request ids, token counts
                and query offsets drive the gather.

        Returns:
            ``(mm_embeds, is_mm_embed)`` as returned by
            ``encoder_runner.gather_mm_embeddings``.
        """
        encoder = self.encoder_runner
        hashes, encoder_kwargs = encoder.prepare_mm_inputs(scheduled_encoder_inputs)
        encoder.execute_mm_encoder(self.model, hashes, encoder_kwargs)

        # Select the per-request prefill bookkeeping for the requests in this
        # batch (indexed through idx_mapping_np).
        batch_rows = input_batch.idx_mapping_np
        states = self.req_states
        embeds, embed_mask = encoder.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            states.prefill_len.np[batch_rows],
            states.num_computed_prefill_tokens[batch_rows],
        )
        return embeds, embed_mask

Woosuk Kwon's avatar
Woosuk Kwon committed
726
727
728
729
730
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
        """Compute logits at the sampling positions and sample tokens.

        Args:
            hidden_states: Model output, indexed by ``input_batch.logits_indices``.
            input_batch: The batch that produced ``hidden_states``.
            grammar_output: When not None, its bitmask is applied to the
                logits before sampling.

        Returns:
            ``(sampler_output, num_sampled, num_rejected)``.
        """
        logits_indices = input_batch.logits_indices
        picked_hidden = hidden_states[logits_indices]
        picked_positions = input_batch.positions[logits_indices]
        picked_input_ids = input_batch.input_ids[logits_indices]

        logits = self.model.compute_logits(picked_hidden)
        if grammar_output is not None:
            # The bitmask is applied to the logits in-place.
            self.structured_outputs_worker.apply_grammar_bitmask(
                logits,
                input_batch,
                grammar_output.structured_output_request_ids,
                grammar_output.grammar_bitmask,
            )

        # Sample tokens (and logprobs, if requested).
        sampler_output = self.sampler(
            logits,
            input_batch.expanded_idx_mapping,
            input_batch.idx_mapping_np,
            input_batch.cu_num_logits_np,
            picked_positions,
            picked_input_ids,
            input_batch.expanded_local_pos,
        )

        if input_batch.num_draft_tokens != 0:
            # Speculative decoding: run rejection sampling on the sampled ids.
            accepted_tokens, num_sampled = rejection_sample(
                sampler_output.sampled_token_ids,
                picked_input_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            sampler_output.sampled_token_ids = accepted_tokens
        else:
            # Common case without draft tokens: one sampled token per request.
            num_sampled = torch.ones(
                input_batch.num_reqs, dtype=torch.int32, device=self.device
            )

        # Derive the final sampled/rejected counts.
        # For chunked prefills, num_sampled and num_rejected are both 0.
        num_sampled, num_rejected = get_num_sampled_and_rejected(
            num_sampled,
            input_batch.seq_lens,
            input_batch.cu_num_logits,
            input_batch.idx_mapping,
            self.req_states.prefill_len.gpu,
        )
        return sampler_output, num_sampled, num_rejected
Woosuk Kwon's avatar
Woosuk Kwon committed
781
782
783
784

    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Fold the sampling results of one step back into the request states.

        Args:
            input_batch: The batch that was just executed.
            sampled_tokens: Sampled token ids for the batch.
            num_sampled: Number of sampled tokens per request.
            num_rejected: Number of rejected tokens per request.
        """
        states = self.req_states

        # Device-side update of the computed-token counters and related state.
        post_update(
            input_batch.idx_mapping,
            states.num_computed_tokens.gpu,
            states.last_sampled_tokens,
            self.sampler.penalties_state.output_bin_counts,
            sampled_tokens,
            num_sampled,
            num_rejected,
            input_batch.query_start_loc,
            states.all_token_ids.gpu,
            states.total_len.gpu,
        )

        # Host-side update: advance the computed prefill counters of the
        # requests in this batch, then clamp every counter to its prefill
        # length (in place).
        prefill_progress = states.num_computed_prefill_tokens
        prefill_progress[input_batch.idx_mapping_np] += input_batch.num_scheduled_tokens
        np.minimum(prefill_progress, states.prefill_len.np, out=prefill_progress)

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        """Run one forward pass and stash the results for `sample_tokens`.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Passed to the model as
                ``intermediate_tensors`` on non-first PP ranks.
            dummy_run: If True, request states are not updated and the model
                runs on a batch built by ``InputBatch.make_dummy``.
            skip_attn_for_dummy_run: Only consulted when ``dummy_run`` is True;
                skips block tables, slot mappings and attention metadata.

        Returns:
            Whatever ``self.kv_connector.no_forward`` returns when there is
            nothing to run; the ``IntermediateTensors`` on non-last PP ranks;
            otherwise None, with the forward results saved in
            ``self.execute_model_state``.
        """
        if not dummy_run:
            # Update the request states.
            self.finish_requests(scheduler_output)
            self.free_states(scheduler_output)
            self.add_requests(scheduler_output)
            self.update_requests(scheduler_output)
            self.block_tables.apply_staged_writes()
            if scheduler_output.total_num_scheduled_tokens == 0:
                # No need to run the model.
                empty_output = self.kv_connector.no_forward(scheduler_output)
                return empty_output

        # Get local cudagraph mode and size.
        num_scheduled = scheduler_output.num_scheduled_tokens
        local_cudagraph_mode, local_cudagraph_size = (
            self.cudagraph_manager.get_cudagraph_runtime_mode(
                num_reqs=len(num_scheduled),
                num_tokens=scheduler_output.total_num_scheduled_tokens,
                # A dummy run skips the zero-token early return above, so the
                # mapping may be empty here; `default=0` avoids the ValueError
                # that max() raises on an empty iterable.
                max_query_len=max(num_scheduled.values(), default=0),
            )
        )

        # DP sync: num_tokens + cudagraph_size + cudagraph_mode
        num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
            get_cudagraph_and_dp_padding(
                scheduler_output.total_num_scheduled_tokens,
                local_cudagraph_size,
                local_cudagraph_mode.value,
                self.parallel_config.data_parallel_size,
                self.parallel_config.data_parallel_rank,
            )
        )
        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
        if num_tokens_after_padding == 0:
            # All DP ranks have zero tokens to run.
            empty_output = self.kv_connector.no_forward(scheduler_output)
            return empty_output

        if not dummy_run:
            # Common case.
            # Prepare all the inputs and copy to the input buffers.
            input_batch = self.prepare_inputs(
                scheduler_output, num_tokens_after_padding
            )
            block_tables, slot_mappings = self.prepare_attn(input_batch)

            if self.lora_config:
                # Activate LoRA adapters.
                lora_inputs = self.lora_state.make_lora_inputs(
                    input_batch.req_ids,
                    input_batch.idx_mapping_np,
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*lora_inputs)

            # Only first PP rank prepares multimodal embeddings.
            if self.supports_mm_inputs and self.is_first_pp_rank:
                mm_embeds, is_mm_embed = self.get_mm_embeddings(
                    scheduler_output.scheduled_encoder_inputs, input_batch
                )
                inputs_embeds = self.encoder_runner.get_inputs_embeds(
                    self.model, input_batch.input_ids, mm_embeds, is_mm_embed
                )
                input_batch.inputs_embeds = inputs_embeds[
                    : input_batch.num_tokens_after_padding
                ]
        else:
            # No actual tokens to run. A dummy run for DP or memory profiling.
            num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
            input_batch = InputBatch.make_dummy(
                num_reqs=num_reqs,
                num_tokens=num_tokens_after_padding,
                input_buffers=self.input_buffers,
                device=self.device,
            )
            if not skip_attn_for_dummy_run:
                block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
            else:
                block_tables = None
                slot_mappings = None
            # FIXME(woosuk): Fix warmup for LoRA.

        # Attention metadata is built unless this is a dummy run that
        # explicitly skips attention.
        attn_metadata = None
        slot_mappings_by_layer = None
        if not (dummy_run and skip_attn_for_dummy_run):
            assert slot_mappings is not None
            slot_mappings_by_layer = build_slot_mappings_by_layer(
                slot_mappings, self.kv_cache_config
            )
            assert block_tables is not None
            attn_metadata = self.model_state.prepare_attn(
                input_batch,
                block_tables,
                slot_mappings,
                self.attn_groups,
                self.kv_cache_config,
            )

        model_inputs = {
            "input_ids": input_batch.input_ids,
            "positions": input_batch.positions,
            "inputs_embeds": input_batch.inputs_embeds,
            # NOTE: Values returned by `prepare_inputs` will override the default
            # values above.
            **self.model_state.prepare_inputs(input_batch, self.req_states),
        }
        if not self.is_first_pp_rank:
            # Update for non-first PP ranks.
            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = None
            model_inputs["intermediate_tensors"] = intermediate_tensors

        # Run model.
        if cudagraph_runtime_mode == CUDAGraphMode.FULL:
            # Use explicit cudagraph replay for FULL mode.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            self.kv_connector.pre_forward(scheduler_output)
            model_output = self.cudagraph_manager.run_fullgraph(
                input_batch.num_tokens_after_padding
            )
        else:
            # For piecewise and eager mode, just call model().
            batch_descriptor = BatchDescriptor(
                num_tokens=input_batch.num_tokens_after_padding,
                has_lora=self.lora_config is not None,
            )

            with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                batch_descriptor=batch_descriptor,
                slot_mapping=slot_mappings_by_layer,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                model_output = self.model(**model_inputs)

        # Both execution paths produce the same output layout, so the
        # (optional) auxiliary hidden states are split off in one place.
        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        kv_connector_output = self.kv_connector.post_forward(scheduler_output)

        # Stash everything `sample_tokens` needs for the second half of the step.
        self.execute_model_state = (
            input_batch,
            model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
        )

        if not self.is_last_pp_rank:
            # Non-last PP rank: return IntermediateTensors for sending.
            assert isinstance(hidden_states, IntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
            return hidden_states
        # Last rank (or no PP): hidden_states is a tensor for sampling.
        assert isinstance(hidden_states, torch.Tensor)
        return None

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: GrammarOutput | None
    ) -> AsyncOutput | ModelRunnerOutput | None:
        """Sample from the state stashed by `execute_model` and finish the step.

        Args:
            grammar_output: Forwarded to `sample`; may be None.

        Returns:
            None when there is no stashed state or on non-last PP ranks;
            otherwise the ``AsyncOutput`` when async scheduling is enabled,
            or its ``get_output()`` result when it is not.
        """
        if self.execute_model_state is None:
            # Nothing was stashed: the prior execute_model call must have failed.
            return None

        (
            input_batch,
            _model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
        ) = self.execute_model_state
        # The stashed state is single-use.
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            # Non-last PP rank: the stashed hidden_states are the
            # IntermediateTensors produced by execute_model, not final hidden
            # states, so there is nothing to sample from. Receive the sampled
            # tokens broadcast from the last rank and update local state.
            received_tokens, num_sampled, num_rejected = pp_receive(
                input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
            )
            self.postprocess(input_batch, received_tokens, num_sampled, num_rejected)
            return None

        # Last rank: sample tokens.
        sampler_output, num_sampled, num_rejected = self.sample(
            hidden_states, input_batch, grammar_output
        )

        if self.use_pp:
            # Broadcast to non-last PP ranks (handles spec decode multi-token).
            pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)

        req_states = self.req_states
        prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
            self.model.compute_logits,
            hidden_states,
            input_batch,
            req_states.all_token_ids.gpu,
            req_states.num_computed_tokens.gpu,
            req_states.prompt_len.np,
            req_states.prefill_len.np,
            req_states.num_computed_prefill_tokens,
        )

        # Assemble the model runner output.
        req_ids = input_batch.req_ids
        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            # NOTE(woosuk): req_id_to_index is unused in this model runner.
            # Only for compatibility with the existing model runner and scheduler.
            req_id_to_index={rid: pos for pos, rid in enumerate(req_ids)},
            sampled_token_ids=None,  # type: ignore
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        # Postprocess results and update request states.
        # NOTE: This intentionally happens after the AsyncOutput is created,
        # so that `copy_event` is recorded before postprocess is called.
        # This sequencing may slightly reduce latency as the async D2H copy
        # does not need to wait for the postprocess to finish.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
        )

        speculator = self.speculator
        if speculator is not None:
            # Propose draft tokens for the next step and record them.
            draft_tokens = speculator.propose(
                input_batch,
                attn_metadata,
                slot_mappings_by_layer,
                hidden_states,
                aux_hidden_states,
                num_sampled,
                num_rejected,
                req_states.last_sampled_tokens,
                req_states.next_prefill_tokens,
                self.sampler.sampling_states.temperature.gpu,
                self.sampler.sampling_states.seeds.gpu,
            )
            req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)

        return async_output if self.use_async_scheduling else async_output.get_output()
1083
1084
1085

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft tokens currently held by the draft-tokens handler.

        Returns:
            Whatever ``self.draft_tokens_handler.get_draft_tokens()`` yields,
            which may be None.
        """
        handler = self.draft_tokens_handler
        return handler.get_draft_tokens()